def _load(model_name_or_path: str, load_weights: bool = False):
    """Build the BERT config, tokenizer and model for a directory or model name.

    If ``model_name_or_path`` is not an existing path, it is looked up in
    ``PRETRAINED_MODEL_ARCHIVE_MAP``. The archive is fetched with
    ``download_url`` and passed to ``untar`` under ``saved/`` unless
    ``saved/<name>`` already exists, and ``saved/<name>`` is then used as the
    model directory.

    Args:
        model_name_or_path: An existing model directory, or a key of
            ``PRETRAINED_MODEL_ARCHIVE_MAP``.
        load_weights: When true and ``pytorch_model.bin`` exists in the model
            directory, its state dict is loaded into the model.

    Returns:
        A ``(bert_config, tokenizer, bert_model)`` tuple.

    Raises:
        KeyError: If ``model_name_or_path`` is neither an existing path nor a
            key of ``PRETRAINED_MODEL_ARCHIVE_MAP``.
    """
    if not os.path.exists(model_name_or_path):
        if model_name_or_path not in PRETRAINED_MODEL_ARCHIVE_MAP:
            raise KeyError("Cannot find the pretrained model {}".format(
                model_name_or_path))
        cached_dir = f"saved/{model_name_or_path}"
        if not os.path.exists(cached_dir):
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[model_name_or_path]
            download_url(archive_file, "saved/", f"{model_name_or_path}.zip")
            untar("saved/", f"{model_name_or_path}.zip")
        model_name_or_path = cached_dir

    # The "version" file is optional: a missing or unreadable file simply
    # selects the default model class below.
    try:
        with open(os.path.join(model_name_or_path, "version"),
                  encoding="utf-8") as version_file:
            version = version_file.readline().strip()
    except (OSError, UnicodeDecodeError):
        version = None

    with open(os.path.join(model_name_or_path, "bert_config.json"),
              encoding="utf-8") as config_file:
        bert_config = BertConfig.from_dict(json.load(config_file))
    tokenizer = BertTokenizer.from_pretrained(model_name_or_path)

    if version == "2":
        bert_model = OAGMetaInfoBertModel(bert_config, tokenizer)
    else:
        bert_model = OAGBertPretrainingModel(bert_config)

    model_weight_path = os.path.join(model_name_or_path, "pytorch_model.bin")
    if load_weights and os.path.exists(model_weight_path):
        # load_state_dict copies values into the model's existing parameters,
        # so mapping the checkpoint to CPU gives the same result while also
        # working on hosts without the device the checkpoint was saved from.
        bert_model.load_state_dict(
            torch.load(model_weight_path, map_location="cpu"))

    return bert_config, tokenizer, bert_model
def load_data(self):
    """Load the train/dev/test splits of ``self.dataset`` from JSON-lines files.

    The splits are read from ``data/supervised_classification/<dataset>/``.
    If that directory does not exist, the archive registered for the dataset
    in ``dataset_url_dict`` is first fetched with ``download_url`` and passed
    to ``untar``.

    Returns:
        A ``(train_data, dev_data, test_data)`` tuple; each element is the
        list of objects decoded from ``train.jsonl``, ``dev.jsonl`` and
        ``test.jsonl`` respectively, in file order.
    """
    rpath = "data/supervised_classification/" + self.dataset
    zip_name = self.dataset + ".zip"
    if not os.path.isdir(rpath):
        download_url(dataset_url_dict[self.dataset], rpath, name=zip_name)
        untar(rpath, zip_name)

    def _read_split(name):
        # One JSON document per line. The file is closed deterministically,
        # and blank lines are skipped instead of crashing json.loads("").
        with open("%s/%s.jsonl" % (rpath, name), encoding="utf-8") as split_file:
            return [
                json.loads(stripped)
                for stripped in map(str.strip, split_file)
                if stripped
            ]

    return _read_split("train"), _read_split("dev"), _read_split("test")
def download(self):
    """Fetch this dataset's zip archive into ``self.raw_dir`` and untar it.

    The lower-cased ``self.name`` is used both for the remote file
    (``<url><name>.zip&dl=1``) and for the local archive name
    (``<name>.zip``).
    """
    stem = self.name.lower()
    archive_name = f"{stem}.zip"
    remote = f"{self.url}{stem}.zip&dl=1"
    download_url(remote, self.raw_dir, archive_name)
    untar(self.raw_dir, archive_name)
def download(self):
    """Fetch ``self.url`` into ``self.raw_dir`` as ``<name>.zip`` and untar it."""
    archive_name = self.name + ".zip"
    target_dir = self.raw_dir
    download_url(self.url, target_dir, name=archive_name)
    untar(target_dir, archive_name)
def download(self):
    """Fetch this dataset's ``.tgz`` archive into ``self.raw_dir`` and untar it.

    The remote file is ``<base_url><name>.tgz&dl=1`` and the local archive is
    stored as ``<name lower-cased>.tgz``.
    """
    archive_name = f"{self.name.lower()}.tgz"
    # NOTE(review): the URL uses self.name as-is while the local file name is
    # lower-cased -- presumably the remote path is case-sensitive; confirm.
    remote = f"{base_url}{self.name}.tgz&dl=1"
    download_url(remote, self.raw_dir, archive_name)
    untar(self.raw_dir, archive_name)
def download(self):
    """Fetch ``self.url`` into ``self.processed_dir`` as ``<name>.zip``, untar it
    and print the destination directory."""
    archive_name = self.name + ".zip"
    download_url(self.url, self.processed_dir, name=archive_name)
    untar(self.processed_dir, archive_name)
    message = f"downloaded to {self.processed_dir}"
    print(message)