def load_model(pretrained_model, sentence, discrim_weights, discrim_meta, device='cpu', cached=False):
    """Load a discriminator (classifier head + backbone) and run prediction on *sentence*.

    Args:
        pretrained_model: name/path of the backbone language model.
        sentence: input text passed to ``predict``.
        discrim_weights: path to the classifier head's state-dict file.
        discrim_meta: path to the JSON metadata (class_size, embed_size, class_vocab).
        device: torch device string for both head and discriminator.
        cached: forwarded to ``Discriminator`` as ``cached_mode``.

    Returns:
        Whatever ``predict`` returns (the original discarded this value).
    """
    with open(discrim_meta, 'r') as discrim_meta_file:
        meta = json.load(discrim_meta_file)
    meta['path'] = discrim_weights

    # Fix: the head was previously forced onto CPU via .to('cpu') even when a
    # different device was requested; honor the `device` argument instead.
    classifier = ClassificationHead(
        class_size=meta["class_size"],
        embed_size=meta["embed_size"]).to(device)
    classifier.load_state_dict(torch.load(discrim_weights, map_location=device))
    classifier.eval()  # single eval() call (was redundantly called twice)

    model = Discriminator(pretrained_model=pretrained_model,
                          classifier_head=classifier,
                          cached_mode=cached,
                          device=device)
    model.eval()

    # Iterating a dict yields its keys; the enumerate/comprehension was noise.
    classes = list(meta["class_vocab"])
    # Return the prediction instead of silently dropping it.
    return predict(sentence, model, classes)
def load_classifier_head(weights_path, meta_path, device=DEVICE):
    """Build an eval-mode ClassificationHead from saved weights and JSON metadata.

    Args:
        weights_path: path to the head's state-dict file.
        meta_path: path to the JSON metadata with 'class_size' and 'embed_size'.
        device: torch device the head is moved to.

    Returns:
        Tuple of (classifier head on *device*, metadata dict).
    """
    with open(meta_path, 'r', encoding="utf8") as meta_file:
        meta = json.load(meta_file)

    head = ClassificationHead(
        class_size=meta['class_size'],
        embed_size=meta['embed_size'],
    ).to(device)

    state = torch.load(weights_path, map_location=device)
    head.load_state_dict(state)
    head.eval()
    return head, meta
def get_classifier(
        name: Optional[str],
        class_label: Union[str, int],
        device: str,
        verbosity_level: int = REGULAR
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    """Load a pre-saved small torch discriminator classifier by name.

    (Docstring translated from Russian.)

    Args:
        name: key into ``DISCRIMINATOR_MODELS_PARAMS``; ``None`` disables the
            discriminator entirely.
        class_label: target class, either a vocab key (str) or a label id (int).
        device: torch device string for the classifier head.
        verbosity_level: fallback warnings print only at >= REGULAR.

    Returns:
        ``(classifier, label_id)``, or ``(None, None)`` when *name* is None.

    Raises:
        ValueError: if the model params specify neither "url" nor "path".
    """
    if name is None:
        return None, None

    params = DISCRIMINATOR_MODELS_PARAMS[name]
    classifier = ClassificationHead(class_size=params['class_size'],
                                    embed_size=params['embed_size']).to(device)
    if "url" in params:
        resolved_archive_file = cached_path(params["url"])
    elif "path" in params:
        resolved_archive_file = params["path"]
    else:
        raise ValueError("Either url or path have to be specified "
                         "in the discriminator model parameters")
    classifier.load_state_dict(
        torch.load(resolved_archive_file, map_location=device))
    classifier.eval()

    # Resolve the label id. The str and int branches previously duplicated the
    # same fallback-and-warn block; collapsed here into one fallback path.
    if isinstance(class_label, str) and class_label in params["class_vocab"]:
        label_id = params["class_vocab"][class_label]
    elif isinstance(class_label, int) and class_label in set(params["class_vocab"].values()):
        label_id = class_label
    else:
        label_id = params["default_class"]
        # Warn only for str/int misses (the original printed nothing for other
        # types) and only at the requested verbosity.
        if isinstance(class_label, (str, int)) and verbosity_level >= REGULAR:
            print("class_label {} not in class_vocab".format(class_label))
            print("available values are: {}".format(params["class_vocab"]))
            print("using default class {}".format(label_id))

    return classifier, label_id
def get_classifier_new(model_path, meta_path, device):
    """Instantiate an eval-mode ClassificationHead described by *meta_path*,
    loading its weights from *model_path* onto *device*.
    """
    with open(meta_path, 'r') as meta_file:
        meta = json.load(meta_file)

    head = ClassificationHead(
        class_size=meta['class_size'],
        embed_size=meta['embed_size'],
    ).to(device)
    head.load_state_dict(torch.load(model_path, map_location=device))
    head.eval()
    return head
def get_classifier(
        model,
        name: Optional[str],
        class_label: Union[str, int],
        device: str) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    """Load a named discriminator head and resolve *class_label* to a label id.

    Note:
        *model* is accepted but never used; it is kept only so existing call
        sites remain compatible.

    Args:
        model: unused (see note above).
        name: key into ``DISCRIMINATOR_MODELS_PARAMS``; ``None`` disables.
        class_label: target class as a vocab key (str) or a label id (int).
        device: torch device string for the classifier head.

    Returns:
        ``(classifier, label_id)``, or ``(None, None)`` when *name* is None.

    Raises:
        ValueError: if the model params specify neither "url" nor "path".
    """
    if name is None:
        return None, None

    params = DISCRIMINATOR_MODELS_PARAMS[name]
    classifier = ClassificationHead(class_size=params["class_size"],
                                    embed_size=params["embed_size"]).to(device)
    if "url" in params:
        resolved_archive_file = cached_path(params["url"])
    elif "path" in params:
        resolved_archive_file = params["path"]
    else:
        raise ValueError(
            "Either url or path have to be specified in the discriminator model parameters"
        )
    classifier.load_state_dict(
        torch.load(resolved_archive_file, map_location=device))
    classifier.eval()

    # The str and int branches previously duplicated the same
    # fallback-and-warn block; collapsed into a single fallback path.
    if isinstance(class_label, str) and class_label in params["class_vocab"]:
        label_id = params["class_vocab"][class_label]
    elif isinstance(class_label, int) and class_label in set(params["class_vocab"].values()):
        label_id = class_label
    else:
        label_id = params["default_class"]
        # Original printed warnings only for str/int misses, never for other types.
        if isinstance(class_label, (str, int)):
            print("class_label {} not in class_vocab".format(class_label))
            print("available values are: {}".format(params["class_vocab"]))
            print("using default class {}".format(label_id))

    return classifier, label_id
def get_classifier(discrim_meta: Optional[dict],
                   device: str) -> Union[ClassificationHead, Tuple[None, None]]:
    """Build an eval-mode ClassificationHead from an in-memory metadata dict.

    NOTE(review): the two return paths have different shapes — the None-input
    branch returns the tuple ``(None, None)``, while the success branch returns
    a bare classifier. The original annotation (``Optional[ClassificationHead]``)
    matched neither branch pair; it is corrected here, but the inconsistency
    itself is left untouched because callers may depend on either shape —
    confirm against call sites before unifying.

    Args:
        discrim_meta: metadata dict with 'class_size', 'embed_size', and either
            a 'url' or a 'path' to the weights; None disables the discriminator.
        device: torch device string the head is moved to.

    Raises:
        ValueError: if the metadata specifies neither "url" nor "path".
    """
    if discrim_meta is None:
        return None, None
    params = discrim_meta
    classifier = ClassificationHead(class_size=params['class_size'],
                                    embed_size=params['embed_size']).to(device)
    # Resolve a local weights file: remote URLs go through the download cache.
    if "url" in params:
        resolved_archive_file = cached_path(params["url"])
    elif "path" in params:
        resolved_archive_file = params["path"]
    else:
        raise ValueError("Either url or path have to be specified "
                         "in the discriminator model parameters")
    classifier.load_state_dict(
        torch.load(resolved_archive_file, map_location=device))
    classifier.eval()
    return classifier
def get_classifier(
        discrim_meta: Optional[dict],
        class_label: Union[str, int],
        device: str) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    """Load the classifier head described by *discrim_meta* and map
    *class_label* to a label id, falling back to the default class.

    Returns:
        ``(classifier, label_id)``, or ``(None, None)`` when *discrim_meta*
        is None.

    Raises:
        ValueError: if the metadata specifies neither "url" nor "path".
    """
    if discrim_meta is None:
        return None, None

    params = discrim_meta
    classifier = ClassificationHead(class_size=params['class_size'],
                                    embed_size=params['embed_size']).to(device)

    if "url" in params:
        resolved_archive_file = cached_path(params["url"])
    elif "path" in params:
        resolved_archive_file = params["path"]
    else:
        raise ValueError("Either url or path have to be specified "
                         "in the discriminator model parameters")

    state_dict = torch.load(resolved_archive_file, map_location=device)
    classifier.load_state_dict(state_dict)
    classifier.eval()

    # Start from the default class and override it only on an exact match.
    label_id = params["default_class"]
    if isinstance(class_label, str):
        label_id = params["class_vocab"].get(class_label, label_id)
    elif isinstance(class_label, int) and class_label in params["class_vocab"].values():
        label_id = class_label

    return classifier, label_id