예제 #1
0
def config_modelloader_and_convert2mlp(config, load_pretrain=True):
    """Build each configured model, load its checkpoint and convert it to an MLP.

    For every non-ignored entry in ``config["models"]`` this:

    1. instantiates ``model_class(**model_params)`` from the module named by
       ``config["model_def"]``;
    2. optionally wraps it with ``add_feature_subsample``;
    3. loads the pretrained weights from ``get_path(config, model_id, "model")``;
    4. converts it with ``convert_conv2d_dense``;
    5. saves the dense model next to the original checkpoint as
       ``<name>_dense.pth``;
    6. reloads that file with ``load_checkpoint_to_mlpany``.

    Args:
        config: Experiment configuration dict. Reads ``model_def`` and
            ``models``, plus ``channel`` / ``dimension`` when subsampling is
            enabled for a model.
        load_pretrain: Must be True for any non-ignored model. The conversion
            works from the saved checkpoint file. The parameter is kept for
            backward compatibility.

    Returns:
        ``(models, model_names)``: the converted MLP models and their
        ``model_id`` values, in config order.

    Raises:
        ValueError: if ``load_pretrain`` is False and at least one model is
            not ignored. The previous code crashed with ``NameError`` here.
    """
    # "model_defs.py" -> "model_defs": strip the extension so the definition
    # file can be imported as a module.
    model_module = importlib.import_module(
        os.path.splitext(config["model_def"])[0])
    models = []
    model_names = []

    for model_config in config["models"]:
        if model_config.get("ignore"):
            continue
        if not load_pretrain:
            # The dense/MLP conversion is driven entirely by the checkpoint
            # file on disk, so there is nothing to convert without it.
            raise ValueError(
                "config_modelloader_and_convert2mlp requires load_pretrain=True")
        model_id = model_config["model_id"]
        model_names.append(model_id)
        model_class = getattr(model_module, model_config["model_class"])
        model_params = model_config["model_params"]
        m = model_class(**model_params)
        if model_config.get("subsample"):
            keep = model_config["subsample_prob"]
            seed = model_config["subsample_seed"]
            m = add_feature_subsample(m, config["channel"],
                                      config["dimension"], keep, seed)

        model_file = get_path(config, model_id, "model")
        print("Loading model file", model_file)
        # The model is never moved to the GPU here, so map the checkpoint to
        # the CPU as well. This lets GPU-saved checkpoints load on CPU-only
        # hosts.
        checkpoint = torch.load(model_file, map_location="cpu")
        state_dict = checkpoint['state_dict']
        if isinstance(state_dict, list):
            # Some checkpoints store a list of state dicts; use the first one.
            state_dict = state_dict[0]
        # Skip entries whose key contains "prev". load_state_dict is strict,
        # so these presumably have no counterpart in the freshly built
        # model. TODO confirm.
        state_dict = {k: v for k, v in state_dict.items() if "prev" not in k}
        m.load_state_dict(state_dict)

        print("convert to dense w")
        dense_m = convert_conv2d_dense(m)
        in_dim = model_params["in_dim"]
        in_ch = model_params["in_ch"]
        # Push one dummy (1, in_ch, in_dim, in_dim) input through the dense
        # model before saving and discard the output. This is presumably a
        # sanity check that the conversion produced a runnable model.
        # TODO confirm.
        dense_m(torch.zeros(1, in_ch, in_dim, in_dim))
        # rsplit only strips the last ".pth", so a ".pth" elsewhere in the
        # path (e.g. in a directory name) is preserved.
        dense_checkpoint_file = model_file.rsplit(".pth", 1)[0] + "_dense.pth"
        print("save dense checkpoint to {}".format(dense_checkpoint_file))
        save_checkpoint(dense_m, dense_checkpoint_file)

        models.append(load_checkpoint_to_mlpany(dense_checkpoint_file))
    return models, model_names
예제 #2
0
def config_modelloader(config, load_pretrain=False, cuda=False):
    """Instantiate every model listed in ``config``, optionally loading weights.

    For each non-ignored entry in ``config["models"]`` this:

    - instantiates ``model_class(**model_params)`` from the module named by
      ``config["model_def"]``;
    - optionally wraps it with ``add_feature_subsample``;
    - optionally moves it to the GPU;
    - optionally loads the checkpoint found at
      ``get_path(config, model_id, "model")``.

    Args:
        config: Experiment configuration dict. Reads ``model_def`` and
            ``models``, plus ``channel`` / ``dimension`` when subsampling is
            enabled for a model.
        load_pretrain: If True, load each model's ``state_dict`` from its
            checkpoint file.
        cuda: If True, call ``.cuda()`` on each model before loading weights.

    Returns:
        ``(models, model_names)``: the built models and their ``model_id``
        values, in config order.
    """
    # "model_defs.py" -> "model_defs": strip the extension so the definition
    # file can be imported as a module.
    model_module = importlib.import_module(
        os.path.splitext(config["model_def"])[0])
    models = []
    model_names = []
    for model_config in config["models"]:
        if model_config.get("ignore"):
            continue
        model_id = model_config["model_id"]
        model_names.append(model_id)
        model_class = getattr(model_module, model_config["model_class"])
        model_params = model_config["model_params"]
        m = model_class(**model_params)
        if model_config.get("subsample"):
            keep = model_config["subsample_prob"]
            seed = model_config["subsample_seed"]
            m = add_feature_subsample(m, config["channel"],
                                      config["dimension"], keep, seed)
        if cuda:
            m.cuda()
        if load_pretrain:
            model_file = get_path(config, model_id, "model")
            print("Loading model file", model_file)
            # On GPU, keep torch.load's default device handling (unchanged
            # behaviour). On CPU, map tensors to the CPU so GPU-saved
            # checkpoints load on CPU-only hosts.
            checkpoint = torch.load(model_file,
                                    map_location=None if cuda else "cpu")
            state_dict = checkpoint['state_dict']
            if isinstance(state_dict, list):
                # Some checkpoints store a list of state dicts; use the first.
                state_dict = state_dict[0]
            # Skip entries whose key contains "prev". load_state_dict is
            # strict, so these presumably have no counterpart in the freshly
            # built model. TODO confirm.
            state_dict = {k: v for k, v in state_dict.items()
                          if "prev" not in k}
            m.load_state_dict(state_dict)
        models.append(m)
    return models, model_names