Пример #1
0
def load_model(pretrained_model,
               sentence,
               discrim_weights,
               discrim_meta,
               device='cpu',
               cached=False):
    """Load a trained discriminator head, wrap it, and run ``predict`` once.

    Builds a ``ClassificationHead`` sized from the JSON metadata file, loads
    its saved weights, wraps it in a ``Discriminator`` built on
    ``pretrained_model`` and calls ``predict(sentence, model, classes)``.

    Args:
        pretrained_model: forwarded to ``Discriminator(pretrained_model=...)``.
        sentence: text passed to ``predict``.
        discrim_weights: path to the saved classifier-head state dict.
        discrim_meta: path to a JSON file with ``class_size``,
            ``embed_size`` and ``class_vocab`` entries.
        device: ``map_location`` for the weights and ``device`` for the
            ``Discriminator``.
        cached: forwarded to ``Discriminator`` as ``cached_mode``.

    Returns:
        The ``Discriminator`` in eval mode, so callers can reuse it without
        reloading the weights (the original returned ``None``).
    """
    # JSON is UTF-8 by spec; don't depend on the platform default encoding.
    with open(discrim_meta, 'r', encoding="utf8") as discrim_meta_file:
        meta = json.load(discrim_meta_file)

    # NOTE(review): the head is pinned to CPU regardless of ``device``.
    # ``map_location=device`` only decides where the checkpoint tensors are
    # loaded; ``load_state_dict`` then copies them into the CPU parameters.
    # Whether ``Discriminator`` relocates its submodules cannot be seen
    # here -- TODO confirm before changing this to ``.to(device)``.
    classifier = ClassificationHead(
        class_size=meta["class_size"],
        embed_size=meta["embed_size"]).to('cpu')
    classifier.load_state_dict(torch.load(discrim_weights,
                                          map_location=device))
    classifier.eval()

    model = Discriminator(pretrained_model=pretrained_model,
                          classifier_head=classifier,
                          cached_mode=cached,
                          device=device)
    model.eval()

    # Iterating ``class_vocab`` yields its keys (label names if it is a dict).
    classes = list(meta["class_vocab"])
    # NOTE(review): ``device`` and ``cached`` are not passed to ``predict``;
    # if ``predict`` accepts them it falls back to its own defaults -- verify.
    predict(sentence, model, classes)
    return model
Пример #2
0
def load_classifier_head(weights_path, meta_path, device=DEVICE):
    """Rebuild a ``ClassificationHead`` from saved weights and JSON metadata.

    Args:
        weights_path: path to the saved state dict of the head.
        meta_path: path to a UTF-8 JSON file providing ``class_size`` and
            ``embed_size``.
        device: device the head is moved to and the weights are mapped to.

    Returns:
        ``(head, meta)``: the head in eval mode and the parsed metadata dict.
    """
    with open(meta_path, 'r', encoding="utf8") as meta_file:
        meta = json.load(meta_file)

    head = ClassificationHead(
        class_size=meta['class_size'],
        embed_size=meta['embed_size'],
    )
    head = head.to(device)

    state_dict = torch.load(weights_path, map_location=device)
    head.load_state_dict(state_dict)
    head.eval()
    return head, meta
Пример #3
0
def get_classifier(
    name: Optional[str],
    class_label: Union[str, int],
    device: str,
    verbosity_level: int = REGULAR
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    """
    Load a previously saved small torch discriminator head by name.

    Args:
        name: key into ``DISCRIMINATOR_MODELS_PARAMS``; ``None`` disables
            the discriminator.
        class_label: a label name (looked up in ``class_vocab``) or a label
            id (checked against the ``class_vocab`` values).
        device: device the head is moved to and the weights are mapped to.
        verbosity_level: fallback notices are printed only when this is at
            least ``REGULAR``.

    Returns:
        ``(classifier, label_id)``, or ``(None, None)`` when ``name`` is None.
        A ``str``/``int`` label missing from the vocabulary falls back to
        ``default_class``; a label of any other type falls back silently.

    Raises:
        ValueError: if the entry specifies neither ``url`` nor ``path``.
    """
    if name is None:
        return None, None

    params = DISCRIMINATOR_MODELS_PARAMS[name]
    head = ClassificationHead(class_size=params['class_size'],
                              embed_size=params['embed_size']).to(device)

    # A remote ``url`` takes precedence over a local ``path``.
    if "url" in params:
        checkpoint = cached_path(params["url"])
    elif "path" in params:
        checkpoint = params["path"]
    else:
        raise ValueError("Either url or path have to be specified "
                         "in the discriminator model parameters")
    head.load_state_dict(torch.load(checkpoint, map_location=device))
    head.eval()

    vocab = params["class_vocab"]
    if isinstance(class_label, str) and class_label in vocab:
        label_id = vocab[class_label]
    elif isinstance(class_label, int) and class_label in set(vocab.values()):
        label_id = class_label
    else:
        label_id = params["default_class"]
        # Only an unknown str/int label is reported; other types are silent.
        if isinstance(class_label, (str, int)) and verbosity_level >= REGULAR:
            print("class_label {} not in class_vocab".format(class_label))
            print("available values are: {}".format(vocab))
            print("using default class {}".format(label_id))

    return head, label_id
Пример #4
0
def get_classifier_new(model_path, meta_path, device):
    """Build a ``ClassificationHead`` from a weights file and a JSON meta file.

    Args:
        model_path: path to the saved state dict of the head.
        meta_path: path to a JSON file providing ``class_size`` and
            ``embed_size``.
        device: device the head is moved to and the weights are mapped to.

    Returns:
        The head in eval mode.
    """
    # Read the metadata as UTF-8 explicitly (consistent with
    # ``load_classifier_head``). JSON is UTF-8 by spec, and the platform
    # default encoding (e.g. cp1252 on Windows) can mis-decode it.
    with open(meta_path, 'r', encoding="utf8") as discrim_meta_file:
        params = json.load(discrim_meta_file)

    classifier = ClassificationHead(
        class_size=params['class_size'],
        embed_size=params['embed_size']
    ).to(device)

    classifier.load_state_dict(
        torch.load(model_path, map_location=device))
    classifier.eval()

    return classifier
Пример #5
0
def get_classifier(
        model, name: Optional[str], class_label: Union[str, int],
        device: str) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    """Load a named pre-trained discriminator head and resolve a class label.

    ``model`` is part of the signature but is not used by this function.

    Args:
        name: key into ``DISCRIMINATOR_MODELS_PARAMS``; ``None`` disables
            the discriminator.
        class_label: a label name (looked up in ``class_vocab``) or a label
            id (checked against the ``class_vocab`` values).
        device: device the head is moved to and the weights are mapped to.

    Returns:
        ``(classifier, label_id)``, or ``(None, None)`` when ``name`` is None.
        A ``str``/``int`` label missing from the vocabulary falls back to
        ``default_class`` with a printed notice; other types fall back
        silently.

    Raises:
        ValueError: if the entry specifies neither ``url`` nor ``path``.
    """
    if name is None:
        return None, None

    params = DISCRIMINATOR_MODELS_PARAMS[name]
    head = ClassificationHead(class_size=params["class_size"],
                              embed_size=params["embed_size"]).to(device)

    if "url" not in params and "path" not in params:
        raise ValueError(
            "Either url or path have to be specified in the discriminator model parameters"
        )
    # A remote ``url`` wins over a local ``path`` when both are present.
    checkpoint = cached_path(params["url"]) if "url" in params else params["path"]

    head.load_state_dict(torch.load(checkpoint, map_location=device))
    head.eval()

    vocab = params["class_vocab"]
    if isinstance(class_label, str) and class_label in vocab:
        label_id = vocab[class_label]
    elif isinstance(class_label, int) and class_label in set(vocab.values()):
        label_id = class_label
    else:
        label_id = params["default_class"]
        if isinstance(class_label, (str, int)):
            print("class_label {} not in class_vocab".format(class_label))
            print("available values are: {}".format(vocab))
            print("using default class {}".format(label_id))

    return head, label_id
Пример #6
0
def get_classifier(discrim_meta: Optional[dict],
                   device: str) -> Optional[ClassificationHead]:
    """Build a ``ClassificationHead`` from a discriminator metadata dict.

    NOTE(review): in this file a later ``def get_classifier`` redefines this
    name, so this version is shadowed at module level.

    Args:
        discrim_meta: dict with ``class_size``, ``embed_size`` and either
            ``url`` (resolved via ``cached_path``) or ``path`` to the saved
            state dict; ``None`` means "no discriminator".
        device: device the head is moved to and the weights are mapped to.

    Returns:
        The head in eval mode, or ``None`` when ``discrim_meta`` is ``None``.
        This matches the ``Optional[ClassificationHead]`` annotation. The
        previous ``(None, None)`` was a truthy tuple of the wrong shape.

    Raises:
        ValueError: if ``discrim_meta`` has neither ``url`` nor ``path``.
    """
    if discrim_meta is None:
        return None

    params = discrim_meta
    classifier = ClassificationHead(class_size=params['class_size'],
                                    embed_size=params['embed_size']).to(device)
    # A remote ``url`` takes precedence over a local ``path``.
    if "url" in params:
        resolved_archive_file = cached_path(params["url"])
    elif "path" in params:
        resolved_archive_file = params["path"]
    else:
        raise ValueError("Either url or path have to be specified "
                         "in the discriminator model parameters")
    classifier.load_state_dict(
        torch.load(resolved_archive_file, map_location=device))
    classifier.eval()

    return classifier


def get_classifier(
        discrim_meta: Optional[dict], class_label: Union[str, int],
        device: str) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    """Build a discriminator head from a metadata dict and resolve a label.

    Args:
        discrim_meta: dict with ``class_size``, ``embed_size``,
            ``class_vocab``, ``default_class`` and either ``url`` (resolved
            via ``cached_path``) or ``path``; ``None`` disables the
            discriminator.
        class_label: a label name (looked up in ``class_vocab``) or a label
            id (checked against the ``class_vocab`` values).
        device: device the head is moved to and the weights are mapped to.

    Returns:
        ``(classifier, label_id)``, or ``(None, None)`` when ``discrim_meta``
        is None. Any label not found in the vocabulary silently falls back
        to ``default_class``.

    Raises:
        ValueError: if ``discrim_meta`` has neither ``url`` nor ``path``.
    """
    if discrim_meta is None:
        return None, None

    meta = discrim_meta
    head = ClassificationHead(class_size=meta['class_size'],
                              embed_size=meta['embed_size']).to(device)

    if "url" not in meta and "path" not in meta:
        raise ValueError("Either url or path have to be specified "
                         "in the discriminator model parameters")
    # A remote ``url`` wins over a local ``path`` when both are present.
    checkpoint = cached_path(meta["url"]) if "url" in meta else meta["path"]

    head.load_state_dict(torch.load(checkpoint, map_location=device))
    head.eval()

    vocab = meta["class_vocab"]
    if isinstance(class_label, str) and class_label in vocab:
        label_id = vocab[class_label]
    elif isinstance(class_label, int) and class_label in set(vocab.values()):
        label_id = class_label
    else:
        label_id = meta["default_class"]

    return head, label_id