Beispiel #1
0
def main():
    """Expose the hyperparameters declared in this file on the CLI, then run ``func``.

    ``_.parse_file(__file__)`` collects the hyperparameter declarations in
    this file and ``hpargparse.bind`` attaches them to the argument parser;
    parsing the command line applies any overrides before ``func()`` runs.
    """
    cli = argparse.ArgumentParser()

    # Gather hyperparameters from this very file and bind them to the parser.
    _.parse_file(__file__)
    hpargparse.bind(cli, _)

    # The parsed namespace is not needed: values are applied to ``_``.
    cli.parse_args()

    func()
Beispiel #2
0
 def _make(self, fpath):
     """Return an ``(ArgumentParser, HyperParameterManager)`` pair for *fpath*.

     The manager (placeholder name ``"_"``) parses the file at *fpath*; the
     parser already carries a positional ``predefined_arg`` before the
     hyperparameters are bound to it.
     """
     path_str = str(fpath)
     manager = hpman.HyperParameterManager("_")

     arg_parser = argparse.ArgumentParser()
     arg_parser.add_argument(dest="predefined_arg")

     manager.parse_file(path_str)
     hpargparse.bind(arg_parser, manager)
     return arg_parser, manager
Beispiel #3
0
        def do_test(assert_type, regex, **bind_kwargs):
            # Fresh manager/parser pair, bound with the options under test.
            manager, arg_parser = self._make_pair()
            hpargparse.bind(arg_parser, manager, **bind_kwargs)

            # assert_type picks a positive (True) or negative (False) regex
            # check; any other key raises KeyError, as a lookup table should.
            checker_by_type = {
                False: self.assertNotRegex,
                True: self.assertRegex,
            }
            check = checker_by_type[assert_type]
            help_text = arg_parser.format_help()
            check(help_text, regex)
Beispiel #4
0
    def test_show_default_value_in_help_message(self):
        """The help text lists the default of every bound hyperparameter."""
        manager, arg_parser = self._make_pair()
        for snippet in ("_('b', True)", "_('c', 'deadbeef')"):
            manager.parse_source(snippet)
        hpargparse.bind(arg_parser, manager)

        help_text = arg_parser.format_help()
        # "default: 1" presumably refers to a hyperparameter registered by
        # _make_pair itself (it is not declared in this test).
        for expected in ("default: 1", "default: True", "default: deadbeef"):
            self.assertRegex(help_text, expected)
Beispiel #5
0
    def test_bind_serial_format(self):
        """``--hp-save`` with ``serial_format="pickle"`` writes a loadable pickle."""
        manager, arg_parser = self._make_pair()
        hpargparse.bind(arg_parser, manager, serial_format="pickle")

        with auto_cleanup_temp_dir() as tmp_dir:
            out_path = tmp_dir / "a.pkl"
            argv = ["--hp-save", str(out_path)]
            arg_parser.parse_args(argv)

            # The file was just written by this test, so unpickling is safe.
            with open(str(out_path), "rb") as fin:
                saved = pickle.load(fin)

            self.assertEqual(saved["a"], 1)
    def test_string_as_default_should_not_be_saved(self):
        """A string passed on the CLI is stored as a plain value, not the default wrapper.

        The hyperparameter ``a`` has the string default ``"hello"``; after
        parsing ``--a world`` its stored value must be the user's string and
        must not be a ``hpargparse.hputils.StringAsDefault`` instance.
        """
        hp_mgr = hpman.HyperParameterManager("_")
        hp_mgr.parse_source('_("a", "hello")')

        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter
        )
        hpargparse.bind(parser, hp_mgr)
        # The returned namespace is not needed; bind applies values to hp_mgr.
        parser.parse_args(["--a", "world"])

        value = hp_mgr.get_value("a")
        self.assertNotIsInstance(value, hpargparse.hputils.StringAsDefault)
        # Without this, the test would also pass if the value was never updated.
        self.assertEqual(value, "world")
    def test_subparser(self):
        """Hyperparameters bound to a subparser are set when that subcommand is parsed."""
        manager = hpman.HyperParameterManager("_")
        manager.parse_source('_("a", 1)')

        root = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter
        )
        sub = root.add_subparsers().add_parser("sub")
        hpargparse.bind(sub, manager)
        root.parse_args(["sub", "--a", str(2)])

        values = manager.get_values()
        self.assertEqual(2, values["a"])
Beispiel #8
0
def main():
    """Bind the hyperparameters found under ``BASE_DIR`` to the CLI and report them.

    ``lib`` is imported only after the command line has been parsed; keep
    that ordering (presumably so that ``lib`` observes the overridden values).
    """
    arg_parser = argparse.ArgumentParser()

    # Ordinary argparse configuration can be added here as usual.
    arg_parser.add_argument(dest="predefined_arg")

    # Collect hyperparameter declarations from everything in this directory.
    _.parse_file(BASE_DIR)  # <-- IMPORTANT

    # bind() monkey-patches arg_parser.parse_args so that parsing also
    # applies the hyperparameter values.
    hpargparse.bind(arg_parser, _)  # <-- IMPORTANT

    # Parsing sets the hyperparameter values; the namespace itself is unused.
    arg_parser.parse_args()

    # Deferred import: must come after the values have been set.
    import lib

    print("a = {}".format(_.get_value("a")))
    print("b = {}".format(_.get_value("b")))
    print("lib.add() = {}".format(lib.add()))
    print("lib.mult() = {}".format(lib.mult()))
Beispiel #9
0
 def test_bind_action_prefix(self):
     """Options generated with a custom action_prefix appear as ``--hahaha-...`` in help."""
     manager, arg_parser = self._make_pair()
     hpargparse.bind(arg_parser, manager, action_prefix="hahaha")
     help_text = arg_parser.format_help()
     self.assertRegex(help_text, r"--hahaha-")
Beispiel #10
0
def _format_metrics(metrics):
    """Render a metrics mapping as space-separated ``key:value`` pairs sorted by key."""
    return " ".join("{}:{}".format(k, v) for k, v in sorted(metrics.items()))


def main():
    """Train and evaluate a model whose settings are hpman hyperparameters.

    Hyperparameters are declared inline with ``_("name", default)`` calls.
    ``_.parse_file(BASE_DIR)`` collects them, ``hpargparse.bind`` exposes them
    as command-line options, and parsing the command line sets their values.
    Each epoch trains over the training split, then evaluates the whole test
    split in one forward pass and prints the metrics.
    """
    parser = argparse.ArgumentParser()
    _.parse_file(BASE_DIR)
    hpargparse.bind(parser, _)
    # Values are applied to ``_`` by the patched parse_args; the returned
    # namespace is not needed.
    parser.parse_args()

    # print all hyperparameters
    print("-" * 10 + " Hyperparameters " + "-" * 10)
    print(yaml.dump(_.get_values()))

    optimizer_cls = {
        "adam": optim.Adam,
        "sgd": functools.partial(optim.SGD, momentum=0.9),
    }[_("optimizer", "adam")]  # <-- hyperparameter

    # Imported only after the hyperparameters are set -- presumably so any
    # ``_()`` calls evaluated inside ``model`` see the overridden values.
    import model

    net = model.get_model()
    if torch.cuda.is_available():
        net.cuda()

    optimizer = optimizer_cls(
        net.parameters(),
        lr=_("learning_rate", 1e-3),  # <-- hyperparameter
        weight_decay=_("weight_decay", 1e-5),  # <-- hyperparameter
    )

    import dataset

    train_ds = dataset.get_data_and_labels("train")
    test_ds = dataset.get_data_and_labels("test")
    if torch.cuda.is_available():
        # since mnist is a small dataset, we store the test dataset all in the
        # gpu memory
        test_ds = {k: v.cuda() for k, v in test_ds.items()}

    rng = np.random.RandomState(_("seed", 42))  # <-- hyperparameter

    for epoch in range(_("num_epochs", 30)):  # <-- hyperparameter
        net.train()
        tq = tqdm(
            enumerate(
                dataset.iter_dataset_batch(
                    rng,
                    train_ds,
                    _("batch_size", 256),  # <-- hyperparameter
                    cuda=torch.cuda.is_available(),
                )))
        for step, minibatch in tq:
            optimizer.zero_grad()

            Y_pred = net(minibatch["data"])
            loss = model.compute_loss(Y_pred, minibatch["labels"])

            loss.backward()
            optimizer.step()

            metrics = model.compute_metrics(Y_pred, minibatch["labels"])
            metrics["loss"] = loss.detach().cpu().numpy()
            tq.desc = "e:{} s:{} {}".format(epoch, step, _format_metrics(metrics))

        net.eval()

        # Inference only: disable autograd so the forward pass over the whole
        # test set does not record a computation graph (saves memory and time).
        with torch.no_grad():
            # since mnist is a small dataset, we predict all values at once.
            Y_pred = net(test_ds["data"])
            metrics = model.compute_metrics(Y_pred, test_ds["labels"])
        print("eval: {}".format(_format_metrics(metrics)))