def predict(opts=None):
    """Run the pipeline in prediction mode.

    Appends the ``evaluation.predict=true`` override before delegating to
    :func:`run`.

    Args:
        opts: Optional list of config-override strings. When ``None``, the
            override is appended to ``sys.argv`` instead (``run`` parses
            ``sys.argv`` in that case).
    """
    if opts is None:
        sys.argv.extend(["evaluation.predict=true"])
    else:
        opts.extend(["evaluation.predict=true"])
    # Bug fix: the original called run(predict=True) without opts, so an
    # explicitly supplied (and just-mutated) option list was silently
    # discarded and only the sys.argv branch ever took effect. Passing
    # opts through is backward compatible: when opts is None this matches
    # the original call.
    run(opts, predict=True)
def _test_user_import_e2e(self, extra_opts=None):
    """End-to-end check of a user_dir experiment.

    Trains the ``simple`` model on the ``always_one`` dataset for a short,
    fixed number of updates, then asserts that both the best-val and the
    test accuracy logged at the final update equal 1.0.

    Args:
        extra_opts: Optional list of additional config overrides appended
            to the base option list.
    """
    extra_opts = [] if extra_opts is None else extra_opts
    MAX_UPDATES = 50
    user_dir = self._get_user_dir()
    with make_temp_dir() as temp_dir:
        # Base overrides plus anything the caller wants layered on top.
        opts = [
            "model=simple",
            "run_type=train_val_test",
            "dataset=always_one",
            "config=configs/experiment.yaml",
            f"env.user_dir={user_dir}",
            "training.seed=1",
            "training.num_workers=3",
            f"training.max_updates={MAX_UPDATES}",
            f"env.save_dir={temp_dir}",
        ] + extra_opts

        # Suppress the runner's stdout chatter during the training run.
        captured = io.StringIO()
        with contextlib.redirect_stdout(captured):
            run(opts)

        train_log = os.path.join(temp_dir, "train.log")
        # Both lookups target the log entry emitted at the final update.
        finished = lambda x: x["progress"] == f"{MAX_UPDATES}/{MAX_UPDATES}"

        record = search_log(
            train_log,
            search_condition=[
                finished,
                lambda x: "best_val/always_one/accuracy" in x,
            ],
        )
        self.assertEqual(float(record["val/always_one/accuracy"]), 1)

        record = search_log(
            train_log,
            search_condition=[
                finished,
                lambda x: "test/always_one/accuracy" in x,
            ],
        )
        self.assertEqual(float(record["test/always_one/accuracy"]), 1)
def predict():
    """Enable prediction mode via a sys.argv override, then invoke the runner."""
    sys.argv += ["evaluation.predict=true"]
    run(predict=True)