コード例 #1
0
def predict(opts=None):
    """Launch ``run`` in prediction mode by adding ``evaluation.predict=true``.

    When ``opts`` is None, the override is appended to ``sys.argv`` (which
    ``run`` presumably parses when given no explicit options -- confirm) and
    ``run(predict=True)`` is called.

    When ``opts`` is provided, the override is appended to a *copy* of
    ``opts`` and that list is forwarded to ``run`` as its first positional
    argument. Previously the extended list was never passed on, so
    caller-supplied options were silently ignored.

    Args:
        opts: optional iterable of command-line style override strings.
            The caller's object is not modified.
    """
    predict_flag = "evaluation.predict=true"
    if opts is None:
        sys.argv.append(predict_flag)
        run(predict=True)
    else:
        # Copy instead of extending in place so the caller's list is untouched.
        run([*opts, predict_flag], predict=True)
コード例 #2
0
ファイル: test_env.py プロジェクト: vishalbelsare/pythia
    def _test_user_import_e2e(self, extra_opts=None):
        """Train ``simple`` on ``always_one`` with a user dir and check final accuracies.

        Builds a fixed list of override options plus ``extra_opts``, runs the
        pipeline with stdout captured, then reads ``train.log`` in the
        temporary save dir and asserts that both the validation and the test
        accuracy logged at the last update equal 1.

        Args:
            extra_opts: optional list of additional override strings appended
                after the base options.
        """
        extra_opts = [] if extra_opts is None else extra_opts

        num_updates = 50
        final_progress = f"{num_updates}/{num_updates}"

        def _final_entry_containing(log_path, key):
            # Find the log entry from the last update that also carries `key`.
            return search_log(
                log_path,
                search_condition=[
                    lambda entry: entry["progress"] == final_progress,
                    lambda entry: key in entry,
                ],
            )

        user_dir = self._get_user_dir()
        with make_temp_dir() as temp_dir:
            base_opts = [
                "model=simple",
                "run_type=train_val_test",
                "dataset=always_one",
                "config=configs/experiment.yaml",
                f"env.user_dir={user_dir}",
                "training.seed=1",
                "training.num_workers=3",
                f"training.max_updates={num_updates}",
                f"env.save_dir={temp_dir}",
            ]
            run_opts = base_opts + extra_opts

            # Swallow training output so it does not clutter the test run.
            with contextlib.redirect_stdout(io.StringIO()):
                run(run_opts)

            log_path = os.path.join(temp_dir, "train.log")

            val_entry = _final_entry_containing(log_path, "best_val/always_one/accuracy")
            self.assertEqual(float(val_entry["val/always_one/accuracy"]), 1)

            test_entry = _final_entry_containing(log_path, "test/always_one/accuracy")
            self.assertEqual(float(test_entry["test/always_one/accuracy"]), 1)
コード例 #3
0
def predict():
    """Enable prediction mode via ``sys.argv`` and call ``run(predict=True)``.

    Appends the ``evaluation.predict=true`` override to the process-wide
    ``sys.argv`` list (which ``run`` presumably parses -- confirm) before
    starting the run.
    """
    cli_args = sys.argv
    cli_args.append("evaluation.predict=true")
    run(predict=True)