def config_modelloader_and_convert2mlp(config, load_pretrain=True):
    """Instantiate the configured models, load their weights and convert them.

    For every entry of ``config["models"]`` that is not flagged ``"ignore"``,
    the model class is looked up in the module named by ``config["model_def"]``
    (file extension stripped) and constructed with the entry's
    ``"model_params"``. When the entry sets ``"subsample"``, the model is
    wrapped with ``add_feature_subsample``.

    When ``load_pretrain`` is true the checkpoint located by ``get_path`` is
    loaded, the model is converted with ``convert_conv2d_dense``, the dense
    model is saved next to the original checkpoint with a ``_dense.pth``
    suffix, and the result of ``load_checkpoint_to_mlpany`` on that file is
    collected.

    Args:
        config: Configuration mapping. Reads ``"model_def"``, ``"models"``
            and, for subsampled models, ``"channel"`` and ``"dimension"``.
            ``"model_params"`` must contain ``"in_dim"`` and ``"in_ch"`` when
            ``load_pretrain`` is true.
        load_pretrain: Whether to load checkpoints and perform the conversion.

    Returns:
        A tuple ``(models, model_names)``. ``model_names`` holds the
        ``"model_id"`` of every non-ignored entry; ``models`` holds the
        converted models.

    NOTE(review): with ``load_pretrain=False`` nothing is converted, so
    ``models`` is empty while ``model_names`` is still populated. This matches
    the previous behaviour; confirm callers do not rely on equal lengths.
    """
    # load the required modelfile
    model_module = importlib.import_module(
        os.path.splitext(config["model_def"])[0])
    models = []
    model_names = []
    for model_config in config["models"]:
        if model_config.get("ignore"):
            continue
        model_id = model_config["model_id"]
        model_names.append(model_id)
        model_class = getattr(model_module, model_config["model_class"])
        model_params = model_config["model_params"]
        m = model_class(**model_params)
        if model_config.get("subsample"):
            keep = model_config["subsample_prob"]
            seed = model_config["subsample_seed"]
            m = add_feature_subsample(
                m, config["channel"], config["dimension"], keep, seed)
        if not load_pretrain:
            # Conversion needs the checkpoint path, so there is nothing more
            # to do for this entry.
            continue

        model_file = get_path(config, model_id, "model")
        print("Loading model file", model_file)
        # The model stays on the CPU in this function, so map the checkpoint
        # to the CPU as well; otherwise a checkpoint saved from a GPU cannot
        # be restored on a CPU-only host.
        checkpoint = torch.load(model_file, map_location="cpu")
        state_dict = checkpoint['state_dict']
        if isinstance(state_dict, list):
            # Some checkpoints store a list of state dicts; use the first.
            state_dict = state_dict[0]
        # Entries whose key contains "prev" are not loaded into the model.
        state_dict = {k: v for k, v in state_dict.items() if "prev" not in k}
        m.load_state_dict(state_dict)

        print("convert to dense w")
        dense_m = convert_conv2d_dense(m)
        in_dim = model_params["in_dim"]
        in_ch = model_params["in_ch"]
        # Run one all-zero input of shape (1, in_ch, in_dim, in_dim) through
        # the dense model; the output itself is not used.
        # NOTE(review): presumably a sanity check of the converted model -
        # confirm whether the forward pass is required before saving.
        dense_m(torch.zeros(1, in_ch, in_dim, in_dim))

        # Derive "<name>_dense.pth" from the file name only, so that a ".pth"
        # occurring in a directory name does not truncate the path.
        directory, file_name = os.path.split(model_file)
        dense_checkpoint_file = os.path.join(
            directory, file_name.split(".pth")[0] + "_dense.pth")
        print("save dense checkpoint to {}".format(dense_checkpoint_file))
        save_checkpoint(dense_m, dense_checkpoint_file)

        mlp_m = load_checkpoint_to_mlpany(dense_checkpoint_file)
        models.append(mlp_m)
    return models, model_names
def config_modelloader(config, load_pretrain=False, cuda=False):
    """Instantiate the models described in ``config`` and optionally load weights.

    For every entry of ``config["models"]`` that is not flagged ``"ignore"``,
    the model class is looked up in the module named by ``config["model_def"]``
    (file extension stripped) and constructed with the entry's
    ``"model_params"``. When the entry sets ``"subsample"``, the model is
    wrapped with ``add_feature_subsample``.

    Args:
        config: Configuration mapping. Reads ``"model_def"``, ``"models"``
            and, for subsampled models, ``"channel"`` and ``"dimension"``.
        load_pretrain: When true, load the checkpoint located by ``get_path``
            into each model, skipping state-dict entries whose key contains
            ``"prev"``.
        cuda: When true, call ``.cuda()`` on each model before any weights
            are loaded.

    Returns:
        A tuple ``(models, model_names)`` of equal length, where
        ``model_names`` holds the ``"model_id"`` of each constructed model.
    """
    # load the required modelfile
    model_module = importlib.import_module(
        os.path.splitext(config["model_def"])[0])
    models = []
    model_names = []
    for model_config in config["models"]:
        if model_config.get("ignore"):
            continue
        model_id = model_config["model_id"]
        model_names.append(model_id)
        model_class = getattr(model_module, model_config["model_class"])
        model_params = model_config["model_params"]
        m = model_class(**model_params)
        if model_config.get("subsample"):
            keep = model_config["subsample_prob"]
            seed = model_config["subsample_seed"]
            m = add_feature_subsample(
                m, config["channel"], config["dimension"], keep, seed)
        if cuda:
            m.cuda()
        if load_pretrain:
            model_file = get_path(config, model_id, "model")
            print("Loading model file", model_file)
            # Map the checkpoint to the CPU so that files saved from a GPU
            # also load on CPU-only hosts; load_state_dict then copies the
            # values into the model's own parameters.
            checkpoint = torch.load(model_file, map_location="cpu")
            state_dict = checkpoint['state_dict']
            if isinstance(state_dict, list):
                # Some checkpoints store a list of state dicts; use the first.
                state_dict = state_dict[0]
            # Entries whose key contains "prev" are not loaded into the model.
            state_dict = {
                k: v for k, v in state_dict.items() if "prev" not in k}
            m.load_state_dict(state_dict)
        models.append(m)
    return models, model_names