def test18_retrain_weight_consolidation(self):
    """Smoke-test ``retrain`` with the ``--weight-consolidation`` option.

    Clears ``OUTPUT_DIR``, parses the command line below with
    ``retrain.parse_args`` and passes the result to ``retrain.main``.
    There are no explicit assertions: the test fails only if one of
    those calls raises.
    """
    # Start from a clean output directory; removal errors are ignored.
    shutil.rmtree(OUTPUT_DIR, ignore_errors=True)

    new_classes = "IN:GET_LOCATION,IN:GET_LOCATION_HOME,SL:POINT_ON_MAP,SL:CATEGORY_LOCATION"
    # One flag per row, followed by its value.
    cli_args = [
        "--data-dir", DATA_BIN,
        "--output-dir", OUTPUT_DIR,
        "--model-dir", MODEL_DIR,
        "--tags", "CLI_tests",
        "--lr", "0.3",
        "--epochs", "5",
        "--log-every", "10",
        "--new-classes", new_classes,
        "--batch-size", "16",
        "--weight-consolidation", "0.1",
    ]

    parsed = retrain.parse_args(cli_args)
    retrain.main(parsed)
def test16_retrain_average(self):
    """Smoke-test ``retrain`` with ``--average-checkpoints``.

    Clears ``OUTPUT_DIR``, parses the command line below with
    ``retrain.parse_args`` and passes the result to ``retrain.main``.
    There are no explicit assertions: the test fails only if one of
    those calls raises.
    """
    # Start from a clean output directory; removal errors are ignored.
    shutil.rmtree(OUTPUT_DIR, ignore_errors=True)

    new_classes = "IN:GET_LOCATION,IN:GET_LOCATION_HOME,SL:POINT_ON_MAP,SL:CATEGORY_LOCATION"
    cli_args = [
        "--data-dir", DATA_BIN,
        "--output-dir", OUTPUT_DIR,
        "--model-dir", MODEL_DIR,
        "--tags", "CLI_tests",
        "--lr", "0.3",
        "--epochs", "5",
        "--log-every", "10",
        "--new-classes", new_classes,
        "--batch-size", "16",
        # The next two flags are passed without a value.
        "--no-lr-scheduler",
        "--average-checkpoints",
        "--new-model-weight", "1",
    ]

    parsed = retrain.parse_args(cli_args)
    retrain.main(parsed)
def test06_retrain_old_data001_merge(self):
    """Smoke-test ``retrain`` with ``--old-data-amount 0.01`` and ``merge_subset``.

    Clears ``OUTPUT_DIR``, parses the command line below with
    ``retrain.parse_args`` and passes the result to ``retrain.main``.
    There are no explicit assertions: the test fails only if one of
    those calls raises.
    """
    # Start from a clean output directory; removal errors are ignored.
    shutil.rmtree(OUTPUT_DIR, ignore_errors=True)

    new_classes = "IN:GET_LOCATION,IN:GET_LOCATION_HOME,SL:POINT_ON_MAP,SL:CATEGORY_LOCATION"
    # One flag per row, followed by its value.
    cli_args = [
        "--data-dir", DATA_BIN,
        "--output-dir", OUTPUT_DIR,
        "--model-dir", MODEL_DIR,
        "--tags", "CLI_tests",
        "--epochs", "1",
        "--new-classes", new_classes,
        "--old-data-amount", "0.01",
        "--early-stopping", "16",
        "--old-data-sampling-method", "merge_subset",
    ]

    parsed = retrain.parse_args(cli_args)
    retrain.main(parsed)
def test14_retrain_limit_iters(self):
    """Smoke-test ``retrain`` with ``--min-steps`` / ``--max-steps`` set.

    Clears ``OUTPUT_DIR``, parses the command line below with
    ``retrain.parse_args`` and passes the result to ``retrain.main``.
    There are no explicit assertions: the test fails only if one of
    those calls raises.
    """
    # Start from a clean output directory; removal errors are ignored.
    shutil.rmtree(OUTPUT_DIR, ignore_errors=True)

    new_classes = "IN:GET_LOCATION,IN:GET_LOCATION_HOME,SL:POINT_ON_MAP,SL:CATEGORY_LOCATION"
    # One flag per row, followed by its value.
    cli_args = [
        "--data-dir", DATA_BIN,
        "--output-dir", OUTPUT_DIR,
        "--model-dir", MODEL_DIR,
        "--tags", "CLI_tests",
        "--epochs", "1",
        "--min-steps", "3",
        "--max-steps", "10",
        "--new-classes", new_classes,
        "--label-smoothing", "0.18",
    ]

    parsed = retrain.parse_args(cli_args)
    retrain.main(parsed)
def test05_retrain_noargs(self):
    """Smoke-test ``retrain`` with only the directory, tag, epoch and class flags.

    Clears ``OUTPUT_DIR``, parses the command line below with
    ``retrain.parse_args`` and passes the result to ``retrain.main``.
    The epoch count comes from the ``EPOCHS`` constant. There are no
    explicit assertions: the test fails only if one of those calls raises.
    """
    # Start from a clean output directory; removal errors are ignored.
    shutil.rmtree(OUTPUT_DIR, ignore_errors=True)

    new_classes = "IN:GET_LOCATION,IN:GET_LOCATION_HOME,SL:POINT_ON_MAP,SL:CATEGORY_LOCATION"
    # One flag per row, followed by its value.
    cli_args = [
        "--data-dir", DATA_BIN,
        "--output-dir", OUTPUT_DIR,
        "--model-dir", MODEL_DIR,
        "--tags", "CLI_tests",
        "--epochs", f"{EPOCHS}",
        "--new-classes", new_classes,
    ]

    parsed = retrain.parse_args(cli_args)
    retrain.main(parsed)