# Example #1
0
def test_get_arch_config(arch: str):
    """Check that each config listed for *arch* is returned as a dictionary.

    Every name from ``fastface.list_arch_configs(arch)`` is fetched with
    ``fastface.get_arch_config`` and asserted to be a ``Dict`` instance.
    """
    for config_name in fastface.list_arch_configs(arch):
        loaded = fastface.get_arch_config(arch, config_name)
        # The message is only built when the assertion fails.
        assert isinstance(loaded, Dict), \
            f"{arch}.{config_name} must be dictionary but found: {type(loaded)}"
# Example #2
0
def test_list_arch_configs(arch: str):
    """Check that listing configs for *arch* yields a list of strings.

    First asserts the returned container is a ``List``, then asserts that
    every element in it is a ``str``.
    """
    config_names = fastface.list_arch_configs(arch)

    assert isinstance(config_names, List), \
        f"returned value must be list but found:{type(config_names)}"

    for name in config_names:
        assert isinstance(name, str), \
            f"architecture config must contain string but found:{type(name)}"
def parse_arguments(argv=None):
    """Parse the command-line options for converting an MXNet ``.params`` model.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, i.e. the normal command-line behaviour; pass
            a list of strings to parse arguments programmatically (e.g. tests).

    Returns:
        An ``argparse.Namespace`` with the attributes ``input_mx_model``,
        ``model_configuration`` and ``output_path``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--input-mx-model',
                    '-im',
                    type=str,
                    help='mxnet .params model to convert',
                    required=True)

    # Valid choices are queried from fastface each time the parser is built.
    # NOTE(review): argparse does not check `default` against `choices`, so
    # confirm '560_25L_8S' is a valid "lffd" config for the installed
    # fastface version (other examples list e.g. "original" / "slim").
    ap.add_argument('--model-configuration',
                    '-mc',
                    type=str,
                    choices=fastface.list_arch_configs("lffd"),
                    default='560_25L_8S',
                    help='lffd architecture configuration to use '
                         '(default: %(default)s)')

    ap.add_argument('--output-path',
                    '-o',
                    type=str,
                    default='./',
                    help='output location (default: %(default)s)')

    # argparse falls back to sys.argv[1:] when argv is None.
    return ap.parse_args(argv)
# Training hyperparameters handed to FaceDetector.build below
# (presumably optimizer / learning-rate-schedule settings — confirm in fastface).
hparams = dict(
    learning_rate=0.1,
    momentum=0.9,
    weight_decay=1e-5,
    milestones=[500_000, 1_000_000, 1_500_000],
    gamma=0.1,
    ratio=10,
)

# Show which architectures can be trained (sample output: ["lffd"]).
print(ff.list_archs())
arch = "lffd"

# Show the configurations of that architecture
# (sample output: ["original", "slim"]).
print(ff.list_arch_configs(arch))
config = "slim"

# Build the pl.LightningModule with randomly initialised weights.
# NOTE(review): `preprocess` is not defined in this snippet; it must be
# provided earlier in the script.
model = ff.FaceDetector.build(
    arch, config=config, preprocess=preprocess, hparams=hparams
)

# Attach an average-precision metric (IoU threshold 0.5) to the model.
model.add_metric(
    "average_precision", ff.metric.AveragePrecision(iou_threshold=0.5)
)

# Checkpoint name "<arch>_<config>_fddb_best"; ckpt_save_path is presumably the
# checkpoint directory used by code after this snippet.
model_save_name = f"{arch}_{config}_fddb_best"
ckpt_save_path = "./checkpoints"
from typing import Iterator, Tuple


def build_module_args() -> Iterator[Tuple[str, str]]:
    """Yield every ``(arch, config)`` pair reported by ``ff``.

    For each architecture name from ``ff.list_archs()``, one pair is produced
    per configuration name from ``ff.list_arch_configs(arch)``, in the order
    those calls return them. Pairs are generated lazily.

    Returns:
        An iterator of ``(arch, config)`` tuples. This is a generator function,
        so the result is an iterator, not a single ``Tuple``.
    """
    for arch in ff.list_archs():
        for config in ff.list_arch_configs(arch):
            yield (arch, config)