Exemple #1
0
    def _load(model_name_or_path: str, load_weights: bool = False):
        """Build the config, tokenizer and model for a pretrained checkpoint.

        ``model_name_or_path`` is used directly as the model directory when it
        exists on disk. Otherwise it must be a key of
        ``PRETRAINED_MODEL_ARCHIVE_MAP``; the model directory is then
        ``saved/<name>``, and when that directory is missing the mapped
        archive is first passed to ``download_url`` and ``untar``.

        Args:
            model_name_or_path: An existing directory, or a key of
                ``PRETRAINED_MODEL_ARCHIVE_MAP``.
            load_weights: When true and ``pytorch_model.bin`` exists in the
                model directory, its state dict is loaded into the model.
                A missing weight file is skipped without an error.

        Returns:
            A ``(bert_config, tokenizer, bert_model)`` tuple.

        Raises:
            KeyError: If the argument is neither an existing path nor a key
                of ``PRETRAINED_MODEL_ARCHIVE_MAP``.
        """
        if not os.path.exists(model_name_or_path):
            if model_name_or_path not in PRETRAINED_MODEL_ARCHIVE_MAP:
                raise KeyError(
                    f"Cannot find the pretrained model {model_name_or_path}")
            local_dir = f"saved/{model_name_or_path}"
            if not os.path.exists(local_dir):
                archive_name = f"{model_name_or_path}.zip"
                download_url(PRETRAINED_MODEL_ARCHIVE_MAP[model_name_or_path],
                             "saved/", archive_name)
                untar("saved/", archive_name)
            model_name_or_path = local_dir

        # The "version" file is optional: its first line selects the model
        # class below. An unreadable or missing file means "no version".
        try:
            with open(os.path.join(model_name_or_path, "version"),
                      encoding="utf-8") as version_file:
                version = version_file.readline().strip()
        except (OSError, ValueError):  # ValueError covers decoding failures
            version = None

        with open(os.path.join(model_name_or_path, "bert_config.json"),
                  encoding="utf-8") as config_file:
            bert_config = BertConfig.from_dict(json.load(config_file))
        tokenizer = BertTokenizer.from_pretrained(model_name_or_path)

        if version == "2":
            bert_model = OAGMetaInfoBertModel(bert_config, tokenizer)
        else:
            bert_model = OAGBertPretrainingModel(bert_config)

        model_weight_path = os.path.join(model_name_or_path,
                                         "pytorch_model.bin")
        if load_weights and os.path.exists(model_weight_path):
            # NOTE: torch.load unpickles the file, so only checkpoints from a
            # trusted source should ever be placed at this path.
            bert_model.load_state_dict(torch.load(model_weight_path))

        return bert_config, tokenizer, bert_model
    def load_data(self):
        """Return the train, dev and test splits of ``self.dataset``.

        The splits are read from ``train.jsonl``, ``dev.jsonl`` and
        ``test.jsonl`` under ``data/supervised_classification/<dataset>``.
        When that directory does not exist, the archive registered for the
        dataset in ``dataset_url_dict`` is first passed to ``download_url``
        and ``untar``.

        Returns:
            A ``(train_data, dev_data, test_data)`` tuple; each element is a
            list holding one decoded JSON value per non-blank line of the
            corresponding file.
        """
        rpath = "data/supervised_classification/" + self.dataset
        if not os.path.isdir(rpath):
            zip_name = self.dataset + ".zip"
            download_url(dataset_url_dict[self.dataset], rpath, name=zip_name)
            untar(rpath, zip_name)

        def _read_split(name):
            # The context manager closes the file as soon as it is parsed.
            # Blank lines (e.g. a trailing newline at the end of the file)
            # are skipped instead of being handed to json.loads.
            with open("%s/%s.jsonl" % (rpath, name),
                      encoding="utf-8") as handle:
                return [
                    json.loads(text) for text in map(str.strip, handle)
                    if text
                ]

        return _read_split("train"), _read_split("dev"), _read_split("test")
Exemple #3
0
 def download(self):
     """Request this dataset's zip archive and hand it to ``untar``.

     The lower-cased ``self.name`` is used both for the local file name
     (``<name>.zip``) and for the link built on top of ``self.url``; the
     archive is placed in ``self.raw_dir`` by ``download_url``.
     """
     stem = self.name.lower()
     archive = f"{stem}.zip"
     source = f"{self.url}{stem}.zip&dl=1"
     download_url(source, self.raw_dir, archive)
     untar(self.raw_dir, archive)
Exemple #4
0
 def download(self):
     """Pass ``self.url`` to ``download_url`` and the result to ``untar``.

     The archive is stored in ``self.raw_dir`` under the name
     ``<self.name>.zip``.
     """
     archive = self.name + ".zip"
     download_url(self.url, self.raw_dir, name=archive)
     untar(self.raw_dir, archive)
Exemple #5
0
 def download(self):
     """Request this dataset's ``.tgz`` archive and hand it to ``untar``.

     The link is built from ``base_url`` and ``self.name``; the archive is
     placed in ``self.raw_dir`` by ``download_url``.
     """
     archive = f"{self.name.lower()}.tgz"
     # The link uses self.name unchanged, while the local file name is
     # lower-cased.
     source = f"{base_url}{self.name}.tgz&dl=1"
     download_url(source, self.raw_dir, archive)
     untar(self.raw_dir, archive)
Exemple #6
0
 def download(self):
     """Pass ``self.url`` to ``download_url`` and the result to ``untar``.

     The archive is stored in ``self.processed_dir`` under the name
     ``<self.name>.zip``; a confirmation line naming that directory is
     printed afterwards.
     """
     archive = self.name + ".zip"
     download_url(self.url, self.processed_dir, name=archive)
     untar(self.processed_dir, archive)
     print("downloaded to {}".format(self.processed_dir))