def get_model_data(name, timings, errors):
    """Collect the summary record (complexity, timings, error, urls) for one model.

    Args:
        name: Model zoo name; used to look up the config file and model info,
            and as the key into `timings` and `errors`.
        timings: Mapping from model name to a mapping that provides the
            "test_fw_time" and "train_fw_bw_time" entries.
        errors: Mapping from model name to a mapping that provides "top1_err".

    Returns:
        A dict with the keys config_url, flops, params, acts, batch_size,
        infer_time, train_time, error, model_id and weight_url.
    """
    # Populate the global config from the model's config file
    reset_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(name))
    cfg_url, _, zoo_id, _, weights_url = model_zoo.get_model_info(name)
    # Complexity of the model selected by the (now loaded) global config
    complexity = net.complexity(builders.get_model())
    model_timings = timings[name]
    # Inference time in ms, normalized to a reference setup of batch size 64 on 1 GPU
    ref_batch_size, ref_num_gpus = 64, 1
    scale = ref_batch_size / cfg.TEST.BATCH_SIZE * cfg.NUM_GPUS / ref_num_gpus
    infer_ms = model_timings["test_fw_time"] * scale * 1000
    # Training time in hours for 100 epochs over the ImageNet train set (1281167 images)
    num_iters = 1281167 / cfg.TRAIN.BATCH_SIZE * 100
    train_hours = model_timings["train_fw_bw_time"] * num_iters / 3600
    # Assemble the record (insertion order matches the reported field order)
    data = {"config_url": "configs/" + cfg_url}
    for metric, unit in (("flops", 1e9), ("params", 1e6), ("acts", 1e6)):
        data[metric] = round(complexity[metric] / unit, 1)
    data["batch_size"] = cfg.TRAIN.BATCH_SIZE
    data["infer_time"] = round(infer_ms)
    data["train_time"] = round(train_hours, 1)
    data["error"] = round(errors[name]["top1_err"], 1)
    data["model_id"] = zoo_id
    data["weight_url"] = weights_url
    return data
def test_timing(key):
    """Measure the timing of a single model.

    Loads the model's config into the global `cfg`, runs `trainer.time_model`
    on `cfg.NUM_GPUS` processes with output logged to a temporary directory,
    and reads the timing results back from the log file.

    Args:
        key: Model identifier passed to `model_zoo.get_config_file`.

    Returns:
        The "iter_times" entry of the sorted log data parsed from "stdout.log".
    """
    reset_cfg()
    merge_from_file(model_zoo.get_config_file(key))
    cfg.PREC_TIME.WARMUP_ITER, cfg.PREC_TIME.NUM_ITER = 5, 50
    # Keep a local handle on the temp dir so cleanup always targets the directory
    # created here, regardless of what happens to cfg.OUT_DIR afterwards
    out_dir = tempfile.mkdtemp()
    try:
        cfg.OUT_DIR, cfg.LOG_DEST = out_dir, "file"
        dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.time_model)
        log_file = os.path.join(out_dir, "stdout.log")
        data = logging.sort_log_data(logging.load_log_data(log_file))["iter_times"]
    finally:
        # mkdtemp leaves deletion to the caller; remove the directory even when
        # the timing run or the log parsing fails so that it is not leaked
        shutil.rmtree(out_dir)
    return data
def test_error(key):
    """Measure the error of a single model.

    Loads the model's config and weights file into the global `cfg`, runs
    `trainer.test_model` on `cfg.NUM_GPUS` processes with output logged to a
    temporary directory, and reads the error values back from the log file.

    Args:
        key: Model identifier passed to `model_zoo.get_config_file` and
            `model_zoo.get_weights_file`.

    Returns:
        A dict with "top1_err" and "top5_err", each taken from the last entry
        of the corresponding list in the "test_epoch" log data.
    """
    reset_cfg()
    merge_from_file(model_zoo.get_config_file(key))
    cfg.TEST.WEIGHTS = model_zoo.get_weights_file(key)
    # Keep a local handle on the temp dir so cleanup always targets the directory
    # created here, regardless of what happens to cfg.OUT_DIR afterwards
    out_dir = tempfile.mkdtemp()
    try:
        cfg.OUT_DIR, cfg.LOG_DEST = out_dir, "file"
        dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.test_model)
        log_file = os.path.join(out_dir, "stdout.log")
        data = logging.sort_log_data(logging.load_log_data(log_file))["test_epoch"]
        data = {"top1_err": data["top1_err"][-1], "top5_err": data["top5_err"][-1]}
    finally:
        # mkdtemp leaves deletion to the caller; remove the directory even when
        # the test run or the log parsing fails so that it is not leaked
        shutil.rmtree(out_dir)
    return data
def build_model(name, pretrained=False, cfg_list=()):
    """Build a predefined model by name (resets and reloads the global config).

    Args:
        name: Model name used to look up the config file (and, when
            `pretrained` is set, the weights file).
        pretrained: If true, load the model's checkpoint into the new model.
        cfg_list: Extra config overrides merged on top of the config file.

    Returns:
        The constructed model.
    """
    # Reset the global config, then apply the model's file and any overrides
    reset_cfg()
    cfg.merge_from_file(get_config_file(name))
    cfg.merge_from_list(cfg_list)
    # Build the model described by the global config
    net_model = builders.build_model()
    if not pretrained:
        return net_model
    # Pretrained weights requested: load them into the freshly built model
    cp.load_checkpoint(get_weights_file(name), net_model)
    return net_model
def sample_cfgs(seed):
    """Sample one chunk of configs and return the unique, valid ones.

    Args:
        seed: Seed for numpy's global RNG; each call should use a unique seed.

    Returns:
        A dict mapping each config's description string to a clone of the
        corresponding config.
    """
    # Seed the RNG (callers are expected to pass a distinct seed per call)
    np.random.seed(seed)
    setup = sweep_cfg.SETUP
    constraints = setup.CONSTRAINTS
    unique_cfgs = {}
    for _ in range(setup.CHUNK_SIZE):
        # The samplers produce a flat [key, val, key, val, ...] parameter list
        sampled = samplers.sample_parameters(setup.SAMPLERS)
        # Build a description string from the (key, val) pairs; skip duplicates
        pairs = zip(sampled[0::2], sampled[1::2])
        desc = " ".join("{} {}".format(param, value) for param, value in pairs)
        if desc in unique_cfgs:
            continue
        # Materialize the sampled config in the global cfg
        reset_cfg()
        cfg.merge_from_other_cfg(setup.BASE_CFG)
        cfg.merge_from_list(sampled)
        # Drop configs that fail the regnet constraints
        if not samplers.check_regnet_constraints(constraints):
            continue
        # Model scaling for the supported model types (also standardizes the cfg)
        if cfg.MODEL.TYPE in ("anynet", "effnet", "regnet"):
            scaler.scale_model()
        # Drop configs that fail the complexity constraints
        if not samplers.check_complexity_constraints(constraints):
            continue
        # Tag the config with its description and keep a copy of it
        cfg.DESC = desc
        unique_cfgs[desc] = cfg.clone()
        # Quota reached: no need to sample the rest of the chunk
        if len(unique_cfgs) == setup.NUM_CONFIGS:
            break
    return unique_cfgs
def test_complexity(key):
    """Compute the complexity of a single model.

    Args:
        key: Config file path relative to `_PYCLS_DIR`.

    Returns:
        The result of `net.complexity` for the model selected by the config.
    """
    # Load the config found under the pycls directory into the global cfg
    reset_cfg()
    merge_from_file(os.path.join(_PYCLS_DIR, key))
    # Measure the model that the loaded config selects
    model = builders.get_model()
    return net.complexity(model)