def __init__(
    self, embeddings_path: str = None, lowercase_tokens: bool = True
) -> None:
    if embeddings_path:
        if PathManager.isdir(embeddings_path):
            serialized_embed_path = os.path.join(
                embeddings_path, PackageFileName.SERIALIZED_EMBED
            )
            raw_embeddings_path = os.path.join(
                embeddings_path, PackageFileName.RAW_EMBED
            )
        elif PathManager.isfile(embeddings_path):
            serialized_embed_path = ""
            raw_embeddings_path = embeddings_path
        else:
            raise FileNotFoundError(
                f"{embeddings_path} not found. Can't load pretrained embeddings."
            )
        if PathManager.isfile(serialized_embed_path):
            try:
                self.load_cached_embeddings(serialized_embed_path)
            except Exception:
                print("Failed to load cached embeddings, loading the raw file.")
                self.load_pretrained_embeddings(
                    raw_embeddings_path, lowercase_tokens=lowercase_tokens
                )
        else:
            self.load_pretrained_embeddings(
                raw_embeddings_path, lowercase_tokens=lowercase_tokens
            )
    else:
        self.embed_vocab = []  # type: List[str]
        self.stoi = {}  # type: Dict[str, int]
        self.embedding_vectors = None  # type: torch.Tensor

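# Usage sketch (an assumption, not from the source; "PretrainedEmbedding" and
# the paths below are illustrative). The constructor accepts either a directory
# holding both a serialized cache and a raw embedding file, or a direct path to
# a raw embedding file:
#
#   emb = PretrainedEmbedding("embeddings/glove_dir")  # cache + raw layout
#   emb = PretrainedEmbedding("embeddings/glove.txt")  # raw file only
#   emb = PretrainedEmbedding()                        # empty vocab, no vectors
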
def load_vocab(self, vocab_file, vocab_size, lowercase_tokens: bool = False):
    """
    Load items into a set from a file containing one item per line.

    Items are added to the set from the top of the file to the bottom, so
    the file should be ordered by preference (if any); e.g., it makes sense
    to order tokens in descending order of corpus frequency.

    Args:
        vocab_file (str): vocab file to load
        vocab_size (int): maximum number of tokens to load; only the first
            vocab_size tokens are loaded if the actual vocab is larger
        lowercase_tokens (bool): whether the tokens should be lowercased
    """
    vocab: Set[str] = set()
    vocab_file = get_absolute_path(vocab_file)
    if PathManager.isfile(vocab_file):
        with PathManager.open(vocab_file, "r") as f:
            for i, line in enumerate(f):
                if vocab_size > 0 and len(vocab) == vocab_size:
                    print(
                        f"Read {i + 1} items from {vocab_file} "
                        f"to load vocab of size {vocab_size}. "
                        f"Skipping rest of the file"
                    )
                    break
                line = line.strip()
                vocab.add(line.lower() if lowercase_tokens else line)
    else:
        print(f"{vocab_file} doesn't exist. Cannot load vocabulary from it")
    return vocab

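# A minimal standalone sketch of the same loading pattern, assuming a plain
# local file so the PathManager indirection can be dropped: read at most
# vocab_size items, one per line, into a set.
def load_vocab_sketch(path: str, vocab_size: int, lowercase: bool = False) -> set:
    vocab = set()
    with open(path, "r") as f:
        for line in f:
            if vocab_size > 0 and len(vocab) == vocab_size:
                break  # keep only the first vocab_size (highest-preference) items
            token = line.strip()
            vocab.add(token.lower() if lowercase else token)
    return vocab
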
def load(self, load_path: str):
    if not (load_path and PathManager.isfile(load_path)):
        raise ValueError(f"Invalid snapshot path: {load_path}")
    with PathManager.open(load_path, "rb") as checkpoint_f:
        # Map all tensors to CPU so the snapshot loads regardless of which
        # device it was saved from.
        state = torch.load(checkpoint_f, map_location=lambda storage, loc: storage)
    return state

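# A standalone equivalent of the CPU-mapping trick with stock torch (the file
# name is illustrative):
#
#   import torch
#   state = torch.load("snapshot.pt", map_location=lambda storage, loc: storage)
#   # equivalently: torch.load("snapshot.pt", map_location="cpu")
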
def add_include(path):
    """
    Import tasks (and associated components) from the folder name.
    """
    eprint("Including:", path)
    modules = glob.glob(os.path.join(path, "*.py"))
    module_names = [
        os.path.basename(f)[:-3].replace("/", ".")
        for f in modules
        if PathManager.isfile(f) and not f.endswith("__init__.py")
    ]
    for mod_name in module_names:
        mod_path = path.replace("/", ".") + "." + mod_name
        eprint("... importing module:", mod_path)
        my_module = importlib.import_module(mod_path)
        for _name, member in inspect.getmembers(my_module, inspect.isclass):
            if member.__module__ != mod_path:
                # Skip classes that were merely imported into this module.
                continue
            elif Task_Deprecated in member.__bases__ or NewTask in member.__bases__:
                eprint("... task:", member.__name__)
                register_tasks(member)
            else:
                eprint("... importing:", member)
                importlib.import_module(mod_path)

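# Usage sketch (the path is illustrative): the folder's "/" separators are
# rewritten to "." to build importable module paths, so the folder must be a
# package reachable from sys.path.
#
#   add_include("my_project/custom_tasks")
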
def prepare_task(
    config: PyTextConfig,
    dist_init_url: str = None,
    device_id: int = 0,
    rank: int = 0,
    world_size: int = 1,
    metric_channels: Optional[List[Channel]] = None,
    metadata: CommonMetadata = None,
) -> Tuple[Task_Deprecated, TrainingState]:
    if world_size > 1 and config.random_seed is None:
        msg = (
            "Must set random seed when using world_size > 1, so that parameters "
            "have the same initialization across workers."
        )
        raise ValueError(msg)

    if rank == 0:
        print("\nParameters: {}\n".format(config), flush=True)
    _set_cuda(config.use_cuda_if_available, device_id, world_size)
    _set_fp16(config.use_fp16, rank)
    _set_distributed(
        rank,
        world_size,
        dist_init_url,
        device_id,
        config.gpu_streams_for_distributed_training,
    )

    if config.random_seed is not None:
        set_random_seeds(config.random_seed, config.use_deterministic_cudnn)

    training_state = None

    if config.auto_resume_from_snapshot:
        # If there are existing checkpoints, resume from the latest one.
        latest_snapshot_path = get_latest_checkpoint_path(
            os.path.dirname(config.save_snapshot_path)
        )
        if latest_snapshot_path:
            config.load_snapshot_path = latest_snapshot_path

    if config.load_snapshot_path:
        assert PathManager.isfile(config.load_snapshot_path)
        if config.use_config_from_snapshot:
            task, _, training_state = load(config.load_snapshot_path)
        else:
            task, _, training_state = load(
                config.load_snapshot_path, overwrite_config=config
            )
        if training_state:
            training_state.rank = rank
    else:
        task = create_task(
            config.task, metadata=metadata, rank=rank, world_size=world_size
        )

    for mc in metric_channels or []:
        task.metric_reporter.add_channel(mc)

    return task, training_state

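# Usage sketch (an assumption): a single-process, single-GPU invocation with an
# already-parsed PyTextConfig. training_state stays None unless a snapshot is
# loaded.
#
#   task, training_state = prepare_task(config, device_id=0, rank=0, world_size=1)
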
def load_analysis_from_path(self):
    assert PathManager.isfile(self.pre_analysis_path), "{} is not a file".format(
        self.pre_analysis_path
    )
    with PathManager.open(self.pre_analysis_path, "r") as fp:
        metric_dict = json.load(fp)
    return metric_dict

def __init__(
    self, config: Config, output_encoded_layers: bool, *args, **kwargs
) -> None:
    super().__init__(config, output_encoded_layers=output_encoded_layers)
    # Load config
    config_file = os.path.join(config.bert_cpt_dir, "bert_config.json")
    bert_config = BertConfig.from_json_file(config_file)
    print("Bert model config {}".format(bert_config))
    # Instantiate model.
    model = BertModel(bert_config)
    weights_path = os.path.join(config.bert_cpt_dir, "pytorch_model.bin")
    # load pre-trained weights if weights_path exists
    if config.load_weights and PathManager.isfile(weights_path):
        state_dict = torch.load(weights_path)

        missing_keys: List[str] = []
        unexpected_keys: List[str] = []
        error_msgs: List[str] = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict,
                prefix,
                local_metadata,
                True,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        load(model, prefix="" if hasattr(model, "bert") else "bert.")
        if len(missing_keys) > 0:
            print(
                "Weights of {} not initialized from pretrained model: {}".format(
                    model.__class__.__name__, missing_keys
                )
            )
        if len(unexpected_keys) > 0:
            print(
                "Weights from pretrained model not used in {}: {}".format(
                    model.__class__.__name__, unexpected_keys
                )
            )

    self.bert = model
    self.projection = (
        torch.nn.Linear(model.config.hidden_size, config.projection_dim)
        if config.projection_dim > 0
        else None
    )
    log_class_usage(__class__)

def load(self, load_path: str, overwrite_config=None):
    """
    Load a checkpoint from disk.

    Args:
        load_path (str): path of the checkpoint file to load

    Returns:
        task (Task), config (PyTextConfig) and training_state (TrainingState)
    """
    if not (load_path and PathManager.isfile(load_path)):
        raise ValueError(f"Invalid snapshot path: {load_path}")
    print(f"Loading model from {load_path}")
    with PathManager.open(load_path, "rb") as checkpoint_f:
        return load_checkpoint(checkpoint_f, overwrite_config)

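# Usage sketch (instance and path names are illustrative): the returned triple
# mirrors the docstring above; pass overwrite_config to replace the config
# stored inside the snapshot.
#
#   task, config, training_state = some_instance.load("/tmp/model.snapshot")
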
def reload_task_for_multi_export(config: PyTextConfig, latest_snapshot_path: str):
    assert PathManager.isfile(latest_snapshot_path)
    # reset flags before exporting new model in case changed by prev export
    cuda.CUDA_ENABLED = False
    precision.FP16_ENABLED = False

    if config.use_config_from_snapshot:
        task, _, _ = load(latest_snapshot_path)
    else:
        task, _, _ = load(latest_snapshot_path, overwrite_config=config)
    print(f"Loaded task from {latest_snapshot_path}!")
    return task

def reload_model_for_multi_export(config: PyTextConfig):
    latest_snapshot_path = config.save_snapshot_path
    print(f"Latest snapshot saved at {latest_snapshot_path}")
    if latest_snapshot_path:
        print("Reloading fresh model from snapshot path for multiple export")
        assert PathManager.isfile(latest_snapshot_path)
        if config.use_config_from_snapshot:
            task, _, _ = load(latest_snapshot_path)
        else:
            task, _, _ = load(latest_snapshot_path, overwrite_config=config)
        print(f"Loaded task from {latest_snapshot_path}!")
        return task.model
    else:
        print("Couldn't get latest snapshot path.")
        return None

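# Sketch (an assumption): reloading a fresh model between successive exports so
# state mutated by one export cannot leak into the next. export_model is a
# hypothetical helper, not part of the source.
#
#   model = reload_model_for_multi_export(config)
#   if model is not None:
#       export_model(model)
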
def read_vocab(vocab_path):
    """
    Create a dictionary from a vocabulary file.
    """
    skipped = 0
    assert PathManager.isfile(vocab_path), vocab_path
    word2id = {BOS_WORD: 0, EOS_WORD: 1, PAD_WORD: 2, UNK_WORD: 3}
    for i in range(SPECIAL_WORDS):
        word2id[SPECIAL_WORD % i] = 4 + i
    counts = {k: 0 for k in word2id.keys()}
    with PathManager.open(vocab_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if "\u2028" in line:
                skipped += 1
                continue
            line = line.rstrip().split()
            if len(line) != 2:
                skipped += 1
                continue
            if line[0] in word2id:
                skipped += 1
                print("%s already in vocab" % line[0])
                continue
            if not line[1].isdigit():
                skipped += 1
                print("Non-numeric count at line %s: %s" % (i, line))
                continue
            # shift ids past the special words and the skipped lines
            word2id[line[0]] = 4 + SPECIAL_WORDS + i - skipped
            counts[line[0]] = int(line[1])
    id2word = {v: k for k, v in word2id.items()}
    dico = Dictionary(id2word, word2id, counts)
    logger.info("Read %i words from the vocabulary file." % len(dico))
    if skipped > 0:
        logger.warning("Skipped %i invalid lines!" % skipped)
    return dico

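# The vocabulary file format implied by the parsing above is one whitespace-
# separated "word count" pair per line (counts are illustrative):
#
#   the 1061396
#   of 593677
#   and 416629
#
# Lines with a different column count, duplicate words, non-numeric counts, or
# the U+2028 line separator are skipped and tallied in `skipped`.
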
def prepare_task(
    config: PyTextConfig,
    dist_init_url: str = None,
    device_id: int = 0,
    rank: int = 0,
    world_size: int = 1,
    metric_channels: Optional[List[Channel]] = None,
    metadata: CommonMetadata = None,
) -> Tuple[Task_Deprecated, TrainingState]:
    if rank == 0:
        print("\nParameters: {}\n".format(config), flush=True)
    _set_cuda(config.use_cuda_if_available, device_id, world_size)
    _set_fp16(config.use_fp16, rank)
    _set_distributed(rank, world_size, dist_init_url, device_id)

    if config.random_seed is not None:
        set_random_seeds(config.random_seed, config.use_deterministic_cudnn)

    training_state = None

    if config.load_snapshot_path and PathManager.isfile(config.load_snapshot_path):
        if config.use_config_from_snapshot:
            task, _, training_state = load(config.load_snapshot_path)
        else:
            task, _, training_state = load(
                config.load_snapshot_path, overwrite_config=config
            )
        if training_state:
            training_state.rank = rank
    else:
        task = create_task(
            config.task, metadata=metadata, rank=rank, world_size=world_size
        )

    for mc in metric_channels or []:
        task.metric_reporter.add_channel(mc)

    return task, training_state

def index_data(path, bin_path, dico):
    """
    Index sentences with a dictionary.
    """
    if bin_path is not None and PathManager.isfile(bin_path):
        print("Loading data from %s ..." % bin_path)
        data = torch.load(bin_path)
        assert dico == data["dico"]
        return data

    positions = []
    sentences = []
    unk_words = {}

    # index sentences
    f = PathManager.open(path, "r", encoding="utf-8")
    for i, line in enumerate(f):
        if i % 1000000 == 0 and i > 0:
            print(i)
        s = line.rstrip().split()
        # warn about empty sentences
        if len(s) == 0:
            print("Empty sentence in line %i." % i)
        # index sentence words
        count_unk = 0
        indexed = []
        for w in s:
            word_id = dico.index(w, no_unk=False)
            # if we find a special word which is not an unknown word, skip it
            if 0 <= word_id < 4 + SPECIAL_WORDS and word_id != 3:
                logger.warning(
                    'Found unexpected special word "%s" (%i)!!' % (w, word_id)
                )
                continue
            assert word_id >= 0
            indexed.append(word_id)
            if word_id == dico.unk_index:
                unk_words[w] = unk_words.get(w, 0) + 1
                count_unk += 1
        # add sentence
        positions.append([len(sentences), len(sentences) + len(indexed)])
        sentences.extend(indexed)
        sentences.append(1)  # EOS index
    f.close()

    # tensorize data
    positions = np.int64(positions)
    if len(dico) < 1 << 16:
        sentences = np.uint16(sentences)
    elif len(dico) < 1 << 31:
        sentences = np.int32(sentences)
    else:
        raise Exception("Dictionary is too big.")
    assert sentences.min() >= 0
    data = {
        "dico": dico,
        "positions": positions,
        "sentences": sentences,
        "unk_words": unk_words,
    }
    if bin_path is not None:
        print("Saving the data to %s ..." % bin_path)
        torch.save(data, bin_path, pickle_protocol=4)

    return data

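# The dtype selection above picks the narrowest integer type that can hold
# every word id: uint16 for vocabularies under 2**16 entries, int32 under
# 2**31. A minimal standalone sketch of the same trick:
import numpy as np


def pack_ids(ids, vocab_size):
    # Choose the smallest dtype able to represent any id in [0, vocab_size).
    if vocab_size < 1 << 16:
        return np.asarray(ids, dtype=np.uint16)
    elif vocab_size < 1 << 31:
        return np.asarray(ids, dtype=np.int32)
    raise ValueError("Dictionary is too big.")
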
def __init__(
    self, config: Config, output_encoded_layers: bool, *args, **kwargs
) -> None:
    super().__init__(config, output_encoded_layers=output_encoded_layers)
    # Load config
    config_file = os.path.join(config.bert_cpt_dir, "config.json")
    local_config_path = PathManager.get_local_path(config_file)
    bert_config = BertConfig.from_json_file(local_config_path)
    print("Bert model config {}".format(bert_config))
    # Instantiate model.
    model = BertModel(bert_config)
    weights_path = os.path.join(config.bert_cpt_dir, "pytorch_model.bin")
    # load pre-trained weights if weights_path exists
    if config.load_weights and PathManager.isfile(weights_path):
        with PathManager.open(weights_path, "rb") as fd:
            state_dict = torch.load(fd)

        missing_keys: List[str] = []
        unexpected_keys: List[str] = []
        error_msgs: List[str] = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        for key in list(state_dict.keys()):
            new_key = None
            if key.endswith("LayerNorm.gamma"):  # compatibility with v0.5 models
                new_key = key.replace("LayerNorm.gamma", "LayerNorm.weight")
            if key.endswith("LayerNorm.beta"):  # compatibility with v0.5 models
                new_key = key.replace("LayerNorm.beta", "LayerNorm.bias")
            if new_key is not None:
                state_dict[new_key] = state_dict.pop(key)

        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict,
                prefix,
                local_metadata,
                True,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        load(model, prefix="" if hasattr(model, "bert") else "bert.")
        if len(missing_keys) > 0:
            print(
                "Weights of {} not initialized from pretrained model: {}".format(
                    model.__class__.__name__, missing_keys
                )
            )
        if len(unexpected_keys) > 0:
            print(
                "Weights from pretrained model not used in {}: {}".format(
                    model.__class__.__name__, unexpected_keys
                )
            )

    self.bert = model
    log_class_usage(__class__)

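# A standalone sketch of the key migration above, assuming a plain dict of
# tensors: rename v0.5-era LayerNorm.gamma/beta entries to the weight/bias
# names current torch.nn.LayerNorm expects.
def migrate_layernorm_keys(state_dict):
    for key in list(state_dict.keys()):
        if key.endswith("LayerNorm.gamma"):
            new_key = key.replace("LayerNorm.gamma", "LayerNorm.weight")
            state_dict[new_key] = state_dict.pop(key)
        elif key.endswith("LayerNorm.beta"):
            new_key = key.replace("LayerNorm.beta", "LayerNorm.bias")
            state_dict[new_key] = state_dict.pop(key)
    return state_dict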