def get_model_based_on_rho(rho,
                           arch,
                           config_only=False,
                           model_config_overrides=None):
    """Build the model configuration for a given ``rho`` and, optionally, the model.

    The per-layer kernel sizes of the three stages are derived from the
    offset ``rho - 7``:

    * ``rho == 7`` gives 3x3 kernels for every ``stage1`` entry except the
      first ``k2s`` entry (always 1) and for the first ``stage2`` ``k1s``
      entry; all remaining entries are 1.
    * Each unit ``rho`` drops below 7 turns one more of those 3s into a 1,
      starting from the first ``stage2`` ``k1s`` entry and moving back
      through ``stage1`` (the first ``stage1`` ``k1s`` entry always stays 3).
    * Each unit ``rho`` rises above 7 turns one more 1 into a 3, moving
      forward through ``stage2`` and then ``stage3``.

    Args:
        rho: Numeric value controlling the kernel sizes as described above.
        arch: Stored unchanged under the ``"arch"`` key of the config.
        config_only: When true, return the config dict instead of a model.
        model_config_overrides: Optional dict passed to ``update_dict``
            together with the generated config (presumably merged on top of
            it -- confirm against ``update_dict``). ``None`` is treated as
            "no overrides".

    Returns:
        The value returned by ``update_dict`` when ``config_only`` is true,
        otherwise ``Network`` constructed from that value.
    """
    # A fresh dict per call avoids sharing one mutable default across calls.
    if model_config_overrides is None:
        model_config_overrides = {}

    # Signed distance of rho from the baseline value of 7.
    extra_kernal_rf = rho - 7

    def shrinking_kernel(threshold):
        # 3 by default; 1 once rho is more than `threshold` below the baseline.
        return 3 - (-extra_kernal_rf > threshold) * 2

    def growing_kernel(threshold):
        # 1 by default; 3 once rho is more than `threshold` above the baseline.
        return 1 + (extra_kernal_rf > threshold) * 2

    model_config = {
        "arch": arch,
        "base_channels": 128,
        "block_type": "basic",
        "depth": 26,
        "input_shape": [10, 2, -1, -1],
        "multi_label": False,
        "n_classes": 10,
        "prediction_threshold": 0.4,
        "stage1": {
            "maxpool": [1, 2, 4],
            "k1s": [
                3,
                shrinking_kernel(6),
                shrinking_kernel(4),
                shrinking_kernel(2),
            ],
            "k2s": [
                1,
                shrinking_kernel(5),
                shrinking_kernel(3),
                shrinking_kernel(1),
            ],
        },
        "stage2": {
            "maxpool": [],
            "k1s": [
                shrinking_kernel(0),
                growing_kernel(1),
                growing_kernel(3),
                growing_kernel(5),
            ],
            "k2s": [
                growing_kernel(0),
                growing_kernel(2),
                growing_kernel(4),
                growing_kernel(6),
            ],
        },
        "stage3": {
            "maxpool": [],
            "k1s": [
                growing_kernel(7),
                growing_kernel(9),
                growing_kernel(11),
                growing_kernel(13),
            ],
            "k2s": [
                growing_kernel(8),
                growing_kernel(10),
                growing_kernel(12),
                growing_kernel(14),
            ],
        },
        "use_bn": True,
        "weight_init": "fixup",
    }
    # Apply caller-supplied overrides to the generated config.
    model_config = update_dict(model_config, model_config_overrides)
    if config_only:
        return model_config
    return Network(model_config)
Example #2
0
    'the pre-trained model path to load, in this case the model is only evaluated'
)

args = parser.parse_args()

# Pick the base configuration: the regular config when no model path was
# given in args.load, otherwise the "eval" variant of the config.
default_conf_path = ("configs/cp_resnet.json"
                     if args.load is None else "configs/cp_resnet_eval.json")
with open(default_conf_path, "r") as text_file:
    default_conf = json.load(text_file)

# overriding the database config
# Build the path once so the message below shows exactly the file opened.
dataset_conf_path = "configs/datasets/" + args.dataset
print(f"\nSelected training dataset is {dataset_conf_path} ...\n")
with open(dataset_conf_path, "r") as text_file:
    dataset_conf = json.load(text_file)
default_conf = utils_funcs.update_dict(default_conf, dataset_conf)

# Name the output directory after the selected architecture and append a
# timestamp (e.g. "Jan05_13.45.10") to the configured path.
default_conf['out_dir'] = default_conf['out_dir'].replace(
    "cp_resnet", args.arch) + datetime.datetime.now().strftime(
        '%b%d_%H.%M.%S')

print("The experiment outputs will be found at: ", default_conf['out_dir'])
# Only the first occurrence of "out" in the path is replaced with "runs".
tensorboard_write_path = default_conf['out_dir'].replace("out", "runs", 1)
print("The experiment tensorboard can be accessed: tensorboard --logdir  ",
      tensorboard_write_path)

print("Rho value : ", args.rho)
print("Use Mix-up : ", args.mixup)

# Import the module under the `models` package named by args.arch.
arch = importlib.import_module('models.{}'.format(args.arch))