Example #1
    def __init__(self,
                 embeddings_path: Optional[str] = None,
                 lowercase_tokens: bool = True) -> None:
        if embeddings_path:
            if PathManager.isdir(embeddings_path):
                serialized_embed_path = os.path.join(
                    embeddings_path, PackageFileName.SERIALIZED_EMBED)
                raw_embeddings_path = os.path.join(embeddings_path,
                                                   PackageFileName.RAW_EMBED)
            elif PathManager.isfile(embeddings_path):
                serialized_embed_path = ""
                raw_embeddings_path = embeddings_path
            else:
                raise FileNotFoundError(
                    f"{embeddings_path} not found. Can't load pretrained embeddings."
                )

            if PathManager.isfile(serialized_embed_path):
                try:
                    self.load_cached_embeddings(serialized_embed_path)
                except Exception:
                    print(
                        "Failed to load cached embeddings, loading the raw file."
                    )
                    self.load_pretrained_embeddings(
                        raw_embeddings_path, lowercase_tokens=lowercase_tokens)
            else:
                self.load_pretrained_embeddings(
                    raw_embeddings_path, lowercase_tokens=lowercase_tokens)
        else:
            self.embed_vocab = []  # type: List[str]
            self.stoi = {}  # type: Dict[str, int]
            self.embedding_vectors = None  # type: torch.Tensor
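A standalone sketch of the same directory-vs-file resolution, using plain os.path in place of PathManager; the cache and raw file names here are illustrative assumptions, not the real PackageFileName constants:

import os

SERIALIZED_EMBED = "embeddings.cached"  # illustrative cache file name
RAW_EMBED = "embeddings.txt"            # illustrative raw file name

def resolve_embedding_paths(embeddings_path):
    # A directory is expected to hold both the cached and the raw file;
    # a plain file is treated as the raw embeddings themselves.
    if os.path.isdir(embeddings_path):
        return (os.path.join(embeddings_path, SERIALIZED_EMBED),
                os.path.join(embeddings_path, RAW_EMBED))
    if os.path.isfile(embeddings_path):
        return "", embeddings_path
    raise FileNotFoundError(
        f"{embeddings_path} not found. Can't load pretrained embeddings.")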
Example #2
    def load_vocab(self,
                   vocab_file,
                   vocab_size,
                   lowercase_tokens: bool = False):
        """
        Loads items into a set from a file containing one item per line.
        Items are added to the set from the top of the file to the bottom,
        so the items should be ordered by preference (if any); e.g., it
        usually makes sense to order tokens by descending corpus frequency.

        Args:
            vocab_file (str): vocab file to load
            vocab_size (int): maximum number of tokens to load; if the file
                contains more entries, only the first vocab_size are kept
            lowercase_tokens (bool): if the tokens should be lowercased
        """
        vocab: Set[str] = set()
        vocab_file = get_absolute_path(vocab_file)
        if PathManager.isfile(vocab_file):
            with PathManager.open(vocab_file, "r") as f:
                for i, line in enumerate(f):
                    if vocab_size > 0 and len(vocab) == vocab_size:
                        print(f"Read {i+1} items from {vocab_file} "
                              f"to load vocab of size {vocab_size}. "
                              f"Skipping rest of the file")
                        break
                    line = line.strip()
                    vocab.add(line.lower() if lowercase_tokens else line)
        else:
            print(f"{vocab_file} doesn't exist or is not a file. "
                  "Cannot load vocabulary from it.")
        return vocab
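A usage sketch, assuming `embedder` is an instance of the class defining this method and `vocab.txt` is a frequency-sorted token list (both names are illustrative):

# vocab.txt, one token per line, most frequent first:
#   The
#   of
#   and
vocab = embedder.load_vocab("vocab.txt", vocab_size=2, lowercase_tokens=True)
# -> {"the", "of"}; reading stops once vocab_size tokens are collected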
Example #3
    def load(self, load_path: str):
        if not (load_path and PathManager.isfile(load_path)):
            raise ValueError(f"Invalid snapshot path {load_path}")
        with PathManager.open(load_path, "rb") as checkpoint_f:
            state = torch.load(checkpoint_f,
                               map_location=lambda storage, loc: storage)
        return state
Example #4
def add_include(path):
    """
    Import tasks (and associated components) from the folder name.
    """
    eprint("Including:", path)
    modules = glob.glob(os.path.join(path, "*.py"))
    module_names = [
        os.path.basename(f)[:-3]
        for f in modules
        if PathManager.isfile(f) and not f.endswith("__init__.py")
    ]
    for mod_name in module_names:
        mod_path = path.replace("/", ".") + "." + mod_name
        eprint("... importing module:", mod_path)
        my_module = importlib.import_module(mod_path)

        for m in inspect.getmembers(my_module, inspect.isclass):
            if m[1].__module__ != mod_path:
                pass
            elif Task_Deprecated in m[1].__bases__ or NewTask in m[1].__bases__:
                eprint("... task:", m[1].__name__)
                register_tasks(m[1])
            else:
                eprint("... importing:", m[1])
                importlib.import_module(mod_path, m[1])
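The discover-and-import core of add_include, reduced to a self-contained sketch; task registration is left out, and the path-to-dotted-module conversion assumes the directory is importable from the working directory:

import glob
import importlib
import os

def import_all_modules(path):
    # Map e.g. plugins/foo.py to the dotted path plugins.foo and import it,
    # so that any registration side effects in the module run.
    package = path.replace("/", ".")
    for f in glob.glob(os.path.join(path, "*.py")):
        if f.endswith("__init__.py"):
            continue
        importlib.import_module(package + "." + os.path.basename(f)[:-3])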
Example #5
def prepare_task(
    config: PyTextConfig,
    dist_init_url: Optional[str] = None,
    device_id: int = 0,
    rank: int = 0,
    world_size: int = 1,
    metric_channels: Optional[List[Channel]] = None,
    metadata: Optional[CommonMetadata] = None,
) -> Tuple[Task_Deprecated, TrainingState]:
    if world_size > 1 and config.random_seed is None:
        msg = (
            "Must set random seed when using world_size > 1, so that parameters have "
            "same initialization across workers."
        )
        raise ValueError(msg)

    if rank == 0:
        print("\nParameters: {}\n".format(config), flush=True)
    _set_cuda(config.use_cuda_if_available, device_id, world_size)
    _set_fp16(config.use_fp16, rank)
    _set_distributed(
        rank,
        world_size,
        dist_init_url,
        device_id,
        config.gpu_streams_for_distributed_training,
    )

    if config.random_seed is not None:
        set_random_seeds(config.random_seed, config.use_deterministic_cudnn)

    training_state = None

    if config.auto_resume_from_snapshot:
        # if there are existing checkpoints, resume from the latest one
        latest_snapshot_path = get_latest_checkpoint_path(
            os.path.dirname(config.save_snapshot_path)
        )
        if latest_snapshot_path:
            config.load_snapshot_path = latest_snapshot_path

    if config.load_snapshot_path:
        assert PathManager.isfile(config.load_snapshot_path)
        if config.use_config_from_snapshot:
            task, _, training_state = load(config.load_snapshot_path)
        else:
            task, _, training_state = load(
                config.load_snapshot_path, overwrite_config=config
            )
        if training_state:
            training_state.rank = rank
    else:
        task = create_task(
            config.task, metadata=metadata, rank=rank, world_size=world_size
        )

    for mc in metric_channels or []:
        task.metric_reporter.add_channel(mc)

    return task, training_state
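The snapshot-resolution order in isolation: auto-resume beats an explicit load path, and an empty result means a fresh task. A sketch; the helper is passed in as a parameter because only its behavior, not its location, matters here:

import os

def resolve_snapshot_path(config, get_latest_checkpoint_path):
    # auto_resume_from_snapshot wins: prefer the newest checkpoint found
    # next to save_snapshot_path over any explicitly configured load path.
    if config.auto_resume_from_snapshot:
        latest = get_latest_checkpoint_path(
            os.path.dirname(config.save_snapshot_path))
        if latest:
            config.load_snapshot_path = latest
    # An empty load_snapshot_path means a fresh task will be created.
    return config.load_snapshot_path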
Example #6
    def load_analysis_from_path(self):
        assert PathManager.isfile(
            self.pre_analysis_path), "{} is not a file".format(
                self.pre_analysis_path)
        with PathManager.open(self.pre_analysis_path, "r") as fp:
            metric_dict = json.load(fp)

        return metric_dict
Example #7
    def __init__(self, config: Config, output_encoded_layers: bool, *args,
                 **kwargs) -> None:
        super().__init__(config, output_encoded_layers=output_encoded_layers)
        # Load config
        config_file = os.path.join(config.bert_cpt_dir, "bert_config.json")
        bert_config = BertConfig.from_json_file(config_file)
        print("Bert model config {}".format(bert_config))
        # Instantiate model.
        model = BertModel(bert_config)
        weights_path = os.path.join(config.bert_cpt_dir, "pytorch_model.bin")
        # load pre-trained weights if weights_path exists
        if config.load_weights and PathManager.isfile(weights_path):
            state_dict = torch.load(weights_path)

            missing_keys: List[str] = []
            unexpected_keys: List[str] = []
            error_msgs: List[str] = []
            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, "_metadata", None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata

            def load(module, prefix=""):
                local_metadata = ({} if metadata is None else metadata.get(
                    prefix[:-1], {}))
                module._load_from_state_dict(
                    state_dict,
                    prefix,
                    local_metadata,
                    True,
                    missing_keys,
                    unexpected_keys,
                    error_msgs,
                )
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + ".")

            load(model, prefix="" if hasattr(model, "bert") else "bert.")
            if len(missing_keys) > 0:
                print(
                    "Weights of {} not initialized from pretrained model: {}".
                    format(model.__class__.__name__, missing_keys))
            if len(unexpected_keys) > 0:
                print(
                    "Weights from pretrained model not used in {}: {}".format(
                        model.__class__.__name__, unexpected_keys))

        self.bert = model
        self.projection = (torch.nn.Linear(model.config.hidden_size,
                                           config.projection_dim)
                           if config.projection_dim > 0 else None)
        log_class_usage(__class__)
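The recursive weight-loading pattern above as a minimal standalone sketch. _load_from_state_dict is a private but real torch.nn.Module hook with the signature used here; the _metadata handling is dropped for brevity:

import torch

def load_with_report(model: torch.nn.Module, state_dict, prefix=""):
    missing_keys, unexpected_keys, error_msgs = [], [], []

    def _load(module, prefix):
        # Each module consumes its own entries from the flat state_dict...
        module._load_from_state_dict(
            state_dict, prefix, {}, True,
            missing_keys, unexpected_keys, error_msgs)
        # ...then recurses into children with an extended key prefix.
        for name, child in module._modules.items():
            if child is not None:
                _load(child, prefix + name + ".")

    _load(model, prefix)
    return missing_keys, unexpected_keys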
Example #8
    def load(self, load_path: str, overwrite_config=None):
        """
        Loads a checkpoint from disk.
        Args:
            load_path (str): path of the checkpoint file to load
        Returns: task (Task), config (PyTextConfig) and training_state (TrainingState)
        """
        if not (load_path and PathManager.isfile(load_path)):
            raise ValueError(f"Invalid snapshot path {load_path}")
        print(f"Loading model from {load_path}")
        with PathManager.open(load_path, "rb") as checkpoint_f:
            return load_checkpoint(checkpoint_f, overwrite_config)
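A usage sketch; `checkpoint_manager` stands in for whatever object defines this method, and the path is illustrative:

task, config, training_state = checkpoint_manager.load("/tmp/model.snapshot")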
Example #9
def reload_task_for_multi_export(config: PyTextConfig,
                                 latest_snapshot_path: str):
    assert PathManager.isfile(latest_snapshot_path)
    # reset flags before exporting new model in case changed by prev export
    cuda.CUDA_ENABLED = False
    precision.FP16_ENABLED = False
    if config.use_config_from_snapshot:
        task, _, _ = load(latest_snapshot_path)
    else:
        task, _, _ = load(latest_snapshot_path, overwrite_config=config)
    print(f"Loaded task from {latest_snapshot_path}!")
    return task
Example #10
def reload_model_for_multi_export(config: PyTextConfig):
    latest_snapshot_path = config.save_snapshot_path
    print(f"Latest snapshot saved at {latest_snapshot_path}")
    if latest_snapshot_path:
        print("Reloading fresh model from snapshot path for multiple export")
        assert PathManager.isfile(latest_snapshot_path)
        if config.use_config_from_snapshot:
            task, _, _ = load(latest_snapshot_path)
        else:
            task, _, _ = load(latest_snapshot_path, overwrite_config=config)
        print(f"Loaded task from {latest_snapshot_path}!")
        return task.model
    else:
        print("Couldn't get latest snapshot path.")
        return None
Example #11
def read_vocab(vocab_path):
    """
    Create a dictionary from a vocabulary file.
    """
    skipped = 0
    assert PathManager.isfile(vocab_path), vocab_path
    word2id = {BOS_WORD: 0, EOS_WORD: 1, PAD_WORD: 2, UNK_WORD: 3}
    for i in range(SPECIAL_WORDS):
        word2id[SPECIAL_WORD % i] = 4 + i
    counts = {k: 0 for k in word2id.keys()}
    f = PathManager.open(vocab_path, "r", encoding="utf-8")
    for i, line in enumerate(f):
        if "\u2028" in line:
            skipped += 1
            continue
        line = line.rstrip().split()
        if len(line) != 2:
            skipped += 1
            continue
        if line[0] in word2id:
            skipped += 1
            print("%s already in vocab" % line[0])
            continue
        if not line[1].isdigit():
            skipped += 1
            print("Non-numeric count at line %s: %s" % (i, line))
            continue
        # shift word ids past the reserved and special words
        word2id[line[0]] = 4 + SPECIAL_WORDS + i - skipped
        counts[line[0]] = int(line[1])
    f.close()
    id2word = {v: k for k, v in word2id.items()}
    dico = Dictionary(id2word, word2id, counts)
    logger.info("Read %i words from the vocabulary file." % len(dico))
    if skipped > 0:
        logger.warning("Skipped %i invalid lines!" % skipped)
    return dico
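The file format implied by the parsing above is one "word count" pair per line, sorted in the order ids should be assigned. A round-trip sketch, assuming SPECIAL_WORDS == 10 so that regular words start at id 14:

import tempfile

# Write a tiny frequency-sorted vocab file and read it back.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
    tmp.write("the 1061396\nof 593677\nand 416629\n")
dico = read_vocab(tmp.name)
# ids 0-3 are BOS/EOS/PAD/UNK, ids 4-13 the special words, then:
#   "the" -> 14, "of" -> 15, "and" -> 16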
Example #12
def prepare_task(
    config: PyTextConfig,
    dist_init_url: Optional[str] = None,
    device_id: int = 0,
    rank: int = 0,
    world_size: int = 1,
    metric_channels: Optional[List[Channel]] = None,
    metadata: Optional[CommonMetadata] = None,
) -> Tuple[Task_Deprecated, TrainingState]:
    if rank == 0:
        print("\nParameters: {}\n".format(config), flush=True)
    _set_cuda(config.use_cuda_if_available, device_id, world_size)
    _set_fp16(config.use_fp16, rank)
    _set_distributed(rank, world_size, dist_init_url, device_id)

    if config.random_seed is not None:
        set_random_seeds(config.random_seed, config.use_deterministic_cudnn)

    training_state = None
    if config.load_snapshot_path and PathManager.isfile(
            config.load_snapshot_path):
        if config.use_config_from_snapshot:
            task, _, training_state = load(config.load_snapshot_path)
        else:
            task, _, training_state = load(config.load_snapshot_path,
                                           overwrite_config=config)
        if training_state:
            training_state.rank = rank
    else:
        task = create_task(config.task,
                           metadata=metadata,
                           rank=rank,
                           world_size=world_size)

    for mc in metric_channels or []:
        task.metric_reporter.add_channel(mc)

    return task, training_state
Example #13
    def index_data(path, bin_path, dico):
        """
        Index sentences with a dictionary.
        """
        if bin_path is not None and PathManager.isfile(bin_path):
            print("Loading data from %s ..." % bin_path)
            data = torch.load(bin_path)
            assert dico == data["dico"]
            return data

        positions = []
        sentences = []
        unk_words = {}

        # index sentences
        f = PathManager.open(path, "r", encoding="utf-8")
        for i, line in enumerate(f):
            if i % 1000000 == 0 and i > 0:
                print(i)
            s = line.rstrip().split()
            # warn on empty sentences (they are still indexed)
            if len(s) == 0:
                print("Empty sentence in line %i." % i)
            # index sentence words
            count_unk = 0
            indexed = []
            for w in s:
                word_id = dico.index(w, no_unk=False)
                # if we find a special word that is not the unknown word,
                # skip that word (the rest of the sentence is still indexed)
                if 0 <= word_id < 4 + SPECIAL_WORDS and word_id != 3:
                    logger.warning(
                        'Found unexpected special word "%s" (%i)!!' % (w, word_id)
                    )
                    continue
                assert word_id >= 0
                indexed.append(word_id)
                if word_id == dico.unk_index:
                    unk_words[w] = unk_words.get(w, 0) + 1
                    count_unk += 1
            # add sentence
            positions.append([len(sentences), len(sentences) + len(indexed)])
            sentences.extend(indexed)
            sentences.append(1)  # EOS index
        f.close()

        # tensorize data
        positions = np.int64(positions)
        if len(dico) < 1 << 16:
            sentences = np.uint16(sentences)
        elif len(dico) < 1 << 31:
            sentences = np.int32(sentences)
        else:
            raise Exception("Dictionary is too big.")
        assert sentences.min() >= 0
        data = {
            "dico": dico,
            "positions": positions,
            "sentences": sentences,
            "unk_words": unk_words,
        }
        if bin_path is not None:
            print("Saving the data to %s ..." % bin_path)
            torch.save(data, bin_path, pickle_protocol=4)

        return data
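The returned structure packs every sentence into one flat 1-D array; positions[i] holds the [start, end) offsets of sentence i, excluding the EOS marker written after it. A decoding sketch:

def get_sentence(data, i):
    # positions[i] = [start, end) into the flat sentences array; the EOS
    # token (index 1) appended after each sentence sits just past `end`.
    start, end = data["positions"][i]
    return data["sentences"][start:end]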
Example #14
    def __init__(
        self, config: Config, output_encoded_layers: bool, *args, **kwargs
    ) -> None:
        super().__init__(config, output_encoded_layers=output_encoded_layers)
        # Load config
        config_file = os.path.join(config.bert_cpt_dir, "config.json")
        local_config_path = PathManager.get_local_path(config_file)
        bert_config = BertConfig.from_json_file(local_config_path)
        print("Bert model config {}".format(bert_config))
        # Instantiate model.
        model = BertModel(bert_config)
        weights_path = os.path.join(config.bert_cpt_dir, "pytorch_model.bin")
        # load pre-trained weights if weights_path exists
        if config.load_weights and PathManager.isfile(weights_path):
            with PathManager.open(weights_path, "rb") as fd:
                state_dict = torch.load(fd)

            missing_keys: List[str] = []
            unexpected_keys: List[str] = []
            error_msgs: List[str] = []
            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, "_metadata", None)
            for key in list(state_dict.keys()):
                new_key = None
                if key.endswith("LayerNorm.gamma"):  # compatibility with v0.5 models
                    new_key = key.replace("LayerNorm.gamma", "LayerNorm.weight")
                if key.endswith("LayerNorm.beta"):  # compatibility with v0.5 models
                    new_key = key.replace("LayerNorm.beta", "LayerNorm.bias")
                if new_key is not None:
                    state_dict[new_key] = state_dict.pop(key)

            if metadata is not None:
                state_dict._metadata = metadata

            def load(module, prefix=""):
                local_metadata = (
                    {} if metadata is None else metadata.get(prefix[:-1], {})
                )
                module._load_from_state_dict(
                    state_dict,
                    prefix,
                    local_metadata,
                    True,
                    missing_keys,
                    unexpected_keys,
                    error_msgs,
                )
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + ".")

            load(model, prefix="" if hasattr(model, "bert") else "bert.")
            if len(missing_keys) > 0:
                print(
                    "Weights of {} not initialized from pretrained model: {}".format(
                        model.__class__.__name__, missing_keys
                    )
                )
            if len(unexpected_keys) > 0:
                print(
                    "Weights from pretrained model not used in {}: {}".format(
                        model.__class__.__name__, unexpected_keys
                    )
                )

        self.bert = model
        log_class_usage(__class__)
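The checkpoint key migration from this example in isolation: older checkpoints (the v0.5 models mentioned in the comments) name LayerNorm parameters gamma/beta, while current torch.nn.LayerNorm modules expect weight/bias. A standalone sketch:

def migrate_layernorm_keys(state_dict):
    # Rename LayerNorm.gamma -> LayerNorm.weight and
    # LayerNorm.beta -> LayerNorm.bias, in place.
    for key in list(state_dict.keys()):
        new_key = None
        if key.endswith("LayerNorm.gamma"):
            new_key = key.replace("LayerNorm.gamma", "LayerNorm.weight")
        elif key.endswith("LayerNorm.beta"):
            new_key = key.replace("LayerNorm.beta", "LayerNorm.bias")
        if new_key is not None:
            state_dict[new_key] = state_dict.pop(key)
    return state_dict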