Ejemplo n.º 1
0
 def __init__(self,
              class_size,
              pretrained_model="gpt2-medium",
              cached_mode=False,
              device='cpu'):
     """Initialize the discriminator: a GPT-2 encoder plus a classification head.

     Args:
         class_size: number of output classes for the classification head.
         pretrained_model: HuggingFace model id used for both tokenizer
             and encoder.
         cached_mode: stored flag (semantics defined by the rest of the
             class — presumably toggles use of precomputed representations;
             not visible here).
         device: device string the instance remembers for later use.
     """
     super(Discriminator, self).__init__()
     # Tokenizer and encoder come from the same pretrained checkpoint.
     self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
     self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
     # The classifier input width must match the transformer hidden size.
     hidden_size = self.encoder.transformer.config.hidden_size
     self.embed_size = hidden_size
     self.classifier_head = ClassificationHead(
         class_size=class_size, embed_size=hidden_size)
     self.cached_mode = cached_mode
     self.device = device
Ejemplo n.º 2
0
def get_classifier(
    name: Optional[str],
    class_label: Union[str, int],
    device: str,
    verbosity_level: int = REGULAR
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    """Load a pretrained discriminator head and resolve the target class id.

    Args:
        name: key into DISCRIMINATOR_MODELS_PARAMS; ``None`` means no
            classifier is requested.
        class_label: either a class name (looked up in the model's
            ``class_vocab``) or a numeric label id (validated against the
            vocab's values).
        device: torch device string; weights are mapped onto it at load time.
        verbosity_level: fallback warnings are printed when >= REGULAR.

    Returns:
        ``(classifier, label_id)`` — both ``None`` when ``name`` is ``None``.

    Raises:
        ValueError: if the model parameters contain neither "url" nor "path".
    """
    if name is None:
        return None, None

    params = DISCRIMINATOR_MODELS_PARAMS[name]
    classifier = ClassificationHead(class_size=params['class_size'],
                                    embed_size=params['embed_size']).to(device)
    # Weights come either from a remote archive (cached locally) or a
    # local file path.
    if "url" in params:
        resolved_archive_file = cached_path(params["url"])
    elif "path" in params:
        resolved_archive_file = params["path"]
    else:
        raise ValueError("Either url or path have to be specified "
                         "in the discriminator model parameters")
    classifier.load_state_dict(
        torch.load(resolved_archive_file, map_location=device))
    classifier.eval()  # inference only; no gradient updates expected

    def _fallback_label_id():
        # Shared fallback used by both lookup branches: warn (when verbose)
        # and fall back to the model's default class.
        fallback = params["default_class"]
        if verbosity_level >= REGULAR:
            print("class_label {} not in class_vocab".format(class_label))
            print("available values are: {}".format(params["class_vocab"]))
            print("using default class {}".format(fallback))
        return fallback

    if isinstance(class_label, str):
        if class_label in params["class_vocab"]:
            label_id = params["class_vocab"][class_label]
        else:
            label_id = _fallback_label_id()
    elif isinstance(class_label, int):
        if class_label in set(params["class_vocab"].values()):
            label_id = class_label
        else:
            label_id = _fallback_label_id()
    else:
        # Unrecognized label type: silently use the default (original
        # behavior — no warning is printed in this branch).
        label_id = params["default_class"]

    return classifier, label_id
Ejemplo n.º 3
0
    def load_attribute_model(self):
        """Load the attribute classifier head from disk and set it to eval mode.

        Reads the classifier hyperparameters (``class_size``, ``embed_size``)
        from the JSON file at ``self.meta_path``, builds a ClassificationHead
        on ``self.device``, and restores its weights from
        ``self.weights_path``.
        """
        print("Loading attribute classifier model")

        with open(self.meta_path, 'r', encoding="utf8") as f:
            self.meta_params = json.load(f)

        # Log the loaded hyperparameters once (the original printed this
        # line twice — an obvious copy-paste duplicate).
        print(f"\t{self.meta_params}")

        self.classifier = ClassificationHead(
            class_size=self.meta_params['class_size'],
            embed_size=self.meta_params['embed_size']
        ).to(self.device)

        self.classifier.load_state_dict(
            torch.load(self.weights_path, map_location=self.device))
        self.classifier.eval()  # inference only