def get_model_based_on_rho(rho, arch, config_only=False, model_config_overrides=None):
    """Build a model config whose per-block kernel sizes are derived from ``rho``.

    ``rho`` is measured against a baseline of 7 (presumably a receptive-field
    control knob).

    At ``rho == 7`` the default layout is used:

    - ``stage1`` uses 3x3 kernels except the first ``k2s`` entry.
    - ``stage2`` uses 3x3 only for the first ``k1s`` entry.
    - ``stage3`` is all 1x1.

    For integer ``rho``, each step below 7 switches one more early kernel from
    3 to 1, working backwards from ``stage2.k1s[0]`` towards
    ``stage1.k1s[1]``. Each step above 7 switches one more later kernel from
    1 to 3, working forwards from ``stage2.k2s[0]`` to ``stage3.k2s[3]``.
    Outside the range 0..22 no further kernels change.

    Args:
        rho: Numeric value compared against fixed integer thresholds.
        arch: Stored verbatim under the ``"arch"`` key of the config.
        config_only: If true, return the config dict instead of a model.
        model_config_overrides: Optional mapping applied on top of the
            generated config via ``update_dict``. ``None`` (the default)
            means no overrides.

    Returns:
        The (overridden) config dict when ``config_only`` is true, otherwise
        ``Network(model_config)``.
    """
    # Signed distance of rho from the baseline layout (rho == 7).
    extra_kernel_rf = rho - 7

    def shrink(threshold):
        # 3 by default; 1 once rho is more than `threshold` below the baseline.
        return 1 if -extra_kernel_rf > threshold else 3

    def grow(threshold):
        # 1 by default; 3 once rho is more than `threshold` above the baseline.
        return 3 if extra_kernel_rf > threshold else 1

    model_config = {
        "arch": arch,
        "base_channels": 128,
        "block_type": "basic",
        "depth": 26,
        "input_shape": [10, 2, -1, -1],
        "multi_label": False,
        "n_classes": 10,
        "prediction_threshold": 0.4,
        "stage1": {
            "maxpool": [1, 2, 4],
            # Thresholds interleave k1s/k2s so kernels flip one layer at a time.
            "k1s": [3, shrink(6), shrink(4), shrink(2)],
            "k2s": [1, shrink(5), shrink(3), shrink(1)],
        },
        "stage2": {
            "maxpool": [],
            "k1s": [shrink(0), grow(1), grow(3), grow(5)],
            "k2s": [grow(0), grow(2), grow(4), grow(6)],
        },
        "stage3": {
            "maxpool": [],
            "k1s": [grow(7), grow(9), grow(11), grow(13)],
            "k2s": [grow(8), grow(10), grow(12), grow(14)],
        },
        "use_bn": True,
        "weight_init": "fixup",
    }

    # A None default avoids sharing one mutable dict across calls.
    if model_config_overrides is None:
        model_config_overrides = {}
    # Apply caller-supplied overrides; update_dict is defined elsewhere.
    model_config = update_dict(model_config, model_config_overrides)
    if config_only:
        return model_config
    # Network is resolved from module scope (defined outside this view).
    return Network(model_config)
'the pre-trained model path to load, in this case the model is only evaluated' ) args = parser.parse_args() if args.load is None: with open("configs/cp_resnet.json", "r") as text_file: default_conf = json.load(text_file) else: with open("configs/cp_resnet_eval.json", "r") as text_file: default_conf = json.load(text_file) # overriding the database config print(f"\nSelected training dataset is configs/datasets{args.dataset} ...\n") with open("configs/datasets/" + args.dataset, "r") as text_file: dataset_conf = json.load(text_file) default_conf = utils_funcs.update_dict(default_conf, dataset_conf) default_conf['out_dir'] = default_conf['out_dir'].replace( "cp_resnet", args.arch) + str( datetime.datetime.now().strftime('%b%d_%H.%M.%S')) print("The experiment outputs will be found at: ", default_conf['out_dir']) tensorboard_write_path = default_conf['out_dir'].replace("out", "runs", 1) print("The experiment tesnorboard can be accessed: tensorboard --logdir ", tensorboard_write_path) print("Rho value : ", args.rho) print("Use Mix-up : ", args.mixup) arch = importlib.import_module('models.{}'.format(args.arch))