示例#1
0
    def test18_retrain_weight_consolidation(self):
        """Smoke-test `retrain` with ``--weight-consolidation 0.1``.

        Runs 5 epochs at lr 0.3 and batch size 16 on the four new
        location classes. Nothing is asserted: the test passes as long as
        argument parsing and ``retrain.main`` complete without raising.
        """
        # Wipe any artifacts left behind by a previous test run.
        shutil.rmtree(OUTPUT_DIR, ignore_errors=True)

        new_classes = (
            "IN:GET_LOCATION,IN:GET_LOCATION_HOME,SL:POINT_ON_MAP,SL:CATEGORY_LOCATION"
        )
        # Each tuple is one CLI option followed by its value (if any).
        cli_options = (
            ("--data-dir", DATA_BIN),
            ("--output-dir", OUTPUT_DIR),
            ("--model-dir", MODEL_DIR),
            ("--tags", "CLI_tests"),
            ("--lr", "0.3"),
            ("--epochs", "5"),
            ("--log-every", "10"),
            ("--new-classes", new_classes),
            ("--batch-size", "16"),
            ("--weight-consolidation", "0.1"),
        )
        # Flatten into the argv-style token list that parse_args expects.
        argv = [token for option in cli_options for token in option]

        parsed_args = retrain.parse_args(argv)
        retrain.main(parsed_args)
示例#2
0
    def test16_retrain_average(self):
        """Smoke-test `retrain` with checkpoint averaging.

        The run disables the LR scheduler and passes
        ``--average-checkpoints`` with ``--new-model-weight 1``. Nothing is
        asserted: the test only checks that retraining completes without
        raising.
        """
        # Wipe any artifacts left behind by a previous test run.
        shutil.rmtree(OUTPUT_DIR, ignore_errors=True)

        new_classes = (
            "IN:GET_LOCATION,IN:GET_LOCATION_HOME,SL:POINT_ON_MAP,SL:CATEGORY_LOCATION"
        )
        # Each tuple is one CLI option followed by its value; boolean
        # switches are 1-tuples so the flattened token order is unchanged.
        cli_options = (
            ("--data-dir", DATA_BIN),
            ("--output-dir", OUTPUT_DIR),
            ("--model-dir", MODEL_DIR),
            ("--tags", "CLI_tests"),
            ("--lr", "0.3"),
            ("--epochs", "5"),
            ("--log-every", "10"),
            ("--new-classes", new_classes),
            ("--batch-size", "16"),
            ("--no-lr-scheduler",),
            ("--average-checkpoints",),
            ("--new-model-weight", "1"),
        )
        # Flatten into the argv-style token list that parse_args expects.
        argv = [token for option in cli_options for token in option]

        parsed_args = retrain.parse_args(argv)
        retrain.main(parsed_args)
示例#3
0
    def test06_retrain_old_data001_merge(self):
        """Smoke-test `retrain` mixing in 1% of old data via ``merge_subset``.

        Runs a single epoch with ``--old-data-amount 0.01`` and
        ``--early-stopping 16``. Nothing is asserted: the test only checks
        that retraining completes without raising.
        """
        # Wipe any artifacts left behind by a previous test run.
        shutil.rmtree(OUTPUT_DIR, ignore_errors=True)

        new_classes = (
            "IN:GET_LOCATION,IN:GET_LOCATION_HOME,SL:POINT_ON_MAP,SL:CATEGORY_LOCATION"
        )
        # Each tuple is one CLI option followed by its value.
        cli_options = (
            ("--data-dir", DATA_BIN),
            ("--output-dir", OUTPUT_DIR),
            ("--model-dir", MODEL_DIR),
            ("--tags", "CLI_tests"),
            ("--epochs", "1"),
            ("--new-classes", new_classes),
            ("--old-data-amount", "0.01"),
            ("--early-stopping", "16"),
            ("--old-data-sampling-method", "merge_subset"),
        )
        # Flatten into the argv-style token list that parse_args expects.
        argv = [token for option in cli_options for token in option]

        parsed_args = retrain.parse_args(argv)
        retrain.main(parsed_args)
示例#4
0
    def test14_retrain_limit_iters(self):
        """Smoke-test `retrain` with bounded step counts and label smoothing.

        Runs one epoch with ``--min-steps 3``, ``--max-steps 10`` and
        ``--label-smoothing 0.18``. Nothing is asserted: the test only checks
        that retraining completes without raising.
        """
        # Wipe any artifacts left behind by a previous test run.
        shutil.rmtree(OUTPUT_DIR, ignore_errors=True)

        new_classes = (
            "IN:GET_LOCATION,IN:GET_LOCATION_HOME,SL:POINT_ON_MAP,SL:CATEGORY_LOCATION"
        )
        # Each tuple is one CLI option followed by its value.
        cli_options = (
            ("--data-dir", DATA_BIN),
            ("--output-dir", OUTPUT_DIR),
            ("--model-dir", MODEL_DIR),
            ("--tags", "CLI_tests"),
            ("--epochs", "1"),
            ("--min-steps", "3"),
            ("--max-steps", "10"),
            ("--new-classes", new_classes),
            ("--label-smoothing", "0.18"),
        )
        # Flatten into the argv-style token list that parse_args expects.
        argv = [token for option in cli_options for token in option]

        parsed_args = retrain.parse_args(argv)
        retrain.main(parsed_args)
示例#5
0
    def test05_retrain_noargs(self):
        """Smoke-test `retrain` with only the required options.

        Only the data/output/model dirs, tags, epoch count (from the
        module-level ``EPOCHS``) and the new class list are passed. Nothing
        is asserted: the test only checks that retraining completes without
        raising.
        """
        # Wipe any artifacts left behind by a previous test run.
        shutil.rmtree(OUTPUT_DIR, ignore_errors=True)

        new_classes = (
            "IN:GET_LOCATION,IN:GET_LOCATION_HOME,SL:POINT_ON_MAP,SL:CATEGORY_LOCATION"
        )
        # Each tuple is one CLI option followed by its value.
        cli_options = (
            ("--data-dir", DATA_BIN),
            ("--output-dir", OUTPUT_DIR),
            ("--model-dir", MODEL_DIR),
            ("--tags", "CLI_tests"),
            ("--epochs", f"{EPOCHS}"),
            ("--new-classes", new_classes),
        )
        # Flatten into the argv-style token list that parse_args expects.
        argv = [token for option in cli_options for token in option]

        parsed_args = retrain.parse_args(argv)
        retrain.main(parsed_args)