Example #1
def is_global_rank_zero():
    """ Helper function to determine if the current process is global_rank 0 (the main process)
    """
    # Try to get the pytorch RANK env var
    # RANK is set by torch.distributed.launch
    rank = get_envint("RANK", None)
    # Must compare against None: rank 0 is falsy, but it IS the main process.
    if rank is not None:
        return rank == 0

    # If not set by pytorch, we need to determine node_rank
    def get_node_rank():
        # Use an equivalent of pytorch lightning's determine_ddp_node_rank()
        node_rank = 0
        # First check if running on a slurm cluster
        # TODO: This check could probably be better
        num_slurm_tasks = get_envint("SLURM_NTASKS", 0)
        if num_slurm_tasks > 0:
            node_rank = get_envint("SLURM_NODEID", 0)
        else:
            node_rank_env = get_envint("NODE_RANK", None)
            group_rank = get_envint("GROUP_RANK", None)
            # Compare against None so a rank of 0 is handled correctly
            if group_rank is not None:
                node_rank = group_rank
            # NODE_RANK takes precedence over GROUP_RANK when both are set
            if node_rank_env is not None:
                node_rank = node_rank_env
        return node_rank

    node_rank = get_node_rank()
    local_rank = get_envint("LOCAL_RANK", 0)
    return node_rank == 0 and local_rank == 0
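A minimal usage sketch (my own illustration, not from the example): the usual reason to call is_global_rank_zero() is to restrict one-per-job side effects, such as creating an output directory, to the main process.

# Hypothetical usage: only the global rank-0 process creates the directory,
# so N workers do not race on the same filesystem operation.
import os

if is_global_rank_zero():
    os.makedirs("results", exist_ok=True)  # "results" is a made-up path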
Example #2
    def _define_logger(self):

        # Use double-checked locking to avoid taking lock unnecessarily.
        if self._logger is not None:
            return self._logger

        with self._logger_lock:
            try:
                self._logger = _logging.getLogger("nemo_logger")
                # By default, silence all loggers except the logger for rank 0
                self.remove_stream_handlers()
                if get_envbool(NEMO_ENV_VARNAME_TESTING, False):
                    old_factory = _logging.getLogRecordFactory()

                    def record_factory(*args, **kwargs):
                        record = old_factory(*args, **kwargs)
                        record.rank = get_envint("RANK", 0)
                        return record

                    _logging.setLogRecordFactory(record_factory)
                    self.add_stream_handlers(formatter=DebugNeMoFormatter)
                elif get_envint("RANK", 0) == 0:
                    self.add_stream_handlers()

            finally:
                self.set_verbosity(verbosity_level=Logger.INFO)

        self._logger.propagate = False
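The comment above refers to double-checked locking: test the cheap condition first, then re-test it after acquiring the lock. A generic sketch of the full pattern (hypothetical LazySingleton name, standard threading module):

import threading

class LazySingleton:
    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get(cls):
        # First check without the lock: fast path once initialized.
        if cls._instance is None:
            with cls._lock:
                # Second check under the lock: another thread may have
                # finished initialization while we waited for the lock.
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance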
Example #3
def is_global_rank_zero():
    """ Helper function to determine if the current process is global_rank 0 (the main process)
    """
    # Try to get the pytorch RANK env var
    # RANK is set by torch.distributed.launch
    rank = get_envint("RANK", None)
    if rank is not None:
        return rank == 0

    # Try to get the SLURM global rank env var
    # SLURM_PROCID is set by SLURM
    slurm_rank = get_envint("SLURM_PROCID", None)
    if slurm_rank is not None:
        return slurm_rank == 0

    # If neither the pytorch nor the SLURM env vars are set,
    # check the NODE_RANK/GROUP_RANK and LOCAL_RANK env vars;
    # assume global_rank is zero if undefined
    node_rank = get_envint("NODE_RANK", get_envint("GROUP_RANK", 0))
    local_rank = get_envint("LOCAL_RANK", 0)
    return node_rank == 0 and local_rank == 0
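A hypothetical check of the precedence this function encodes (assuming get_envint reads os.environ at call time): RANK wins over SLURM_PROCID, which wins over the NODE_RANK/LOCAL_RANK fallback.

import os

os.environ["RANK"] = "3"
os.environ["SLURM_PROCID"] = "0"
assert not is_global_rank_zero()  # RANK=3 is consulted first

del os.environ["RANK"]
assert is_global_rank_zero()      # falls through to SLURM_PROCID == 0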
Example #4
    def _define_logger(self):

        # Use double-checked locking to avoid taking lock unnecessarily.
        if self._logger is not None:
            return self._logger

        with self._logger_lock:
            try:
                self._logger = _logging.getLogger("nemo_logger")
                # By default, silence all loggers except the logger for rank 0
                self.remove_stream_handlers()
                if get_envint("RANK", 0) == 0:
                    self.add_stream_handlers()

            finally:
                self.set_verbosity(verbosity_level=Logger.INFO)

        self._logger.propagate = False
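The effect is that only the RANK 0 process gets a stream handler, so worker processes stay quiet on the console. A minimal reproduction with the standard logging module (the "demo" logger name is hypothetical, not NeMo API):

import logging
import os

logger = logging.getLogger("demo")  # hypothetical logger name
logger.setLevel(logging.INFO)
# Mirror the example: attach a console handler only on rank 0.
if int(os.environ.get("RANK", "0")) == 0:
    logger.addHandler(logging.StreamHandler())
logger.propagate = False
logger.info("visible only on rank 0")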
Example #5
def dataset_to_ids(
    dataset,
    tokenizer,
    cache_ids=False,
    add_bos_eos=True,
    cache_data_per_node=False,
    use_cache=False,
):
    """
    Reads dataset from file line by line, tokenizes each line with tokenizer,
    and returns list of lists which corresponds to ids of tokenized strings.

    Args:
        dataset (str): path to dataset
        tokenizer: tokenizer to convert text into ids
        cache_ids (bool): if True, ids are saved to disk as pickle file
            with similar name (e.g., data.txt --> data.txt.pkl)
        add_bos_eos (bool): whether to add <s> and </s> symbols (e.g., for NMT)
        cache_data_per_node (bool): Cache data on local_rank 0. Use when there is not a shared-filesystem.
        use_cache (bool): Use cached ids if they exist.
    Returns:
        ids: list of ids which correspond to tokenized strings of the dataset
    """

    cached_ids_dataset = dataset + ".pkl"
    if use_cache and os.path.isfile(cached_ids_dataset):
        logging.info("Loading cached tokenized dataset ...")
        with open(cached_ids_dataset, "rb") as f:
            ids = pickle.load(f)
    else:
        logging.info(f"Tokenizing dataset {dataset}...")
        with open(dataset, "rb") as f:
            data = f.readlines()
        ids = []
        for sentence in tqdm(data, desc="Tokenizing sentence"):
            sent_ids = tokenizer.text_to_ids(sentence.decode("utf-8"))
            if add_bos_eos:
                sent_ids = [tokenizer.bos_id] + sent_ids + [tokenizer.eos_id]
            ids.append(sent_ids)
        if cache_ids and (not torch.distributed.is_initialized() or
                          (cache_data_per_node
                           and get_envint("LOCAL_RANK", 0) == 0)):
            logging.info("Caching tokenized dataset ...")
            with open(cached_ids_dataset, "wb") as f:
                pickle.dump(ids, f)
    return ids
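A hypothetical call sketch (the path and tokenizer are made up; the function only touches the tokenizer's text_to_ids, bos_id, and eos_id members):

class WhitespaceTokenizer:
    """Hypothetical stand-in: any object with these three members works."""
    bos_id, eos_id = 1, 2
    def text_to_ids(self, text):
        return [hash(w) % 30000 + 3 for w in text.split()]

ids = dataset_to_ids(
    "train.txt",               # made-up path, one sentence per line
    WhitespaceTokenizer(),
    cache_ids=True,            # writes train.txt.pkl
    add_bos_eos=True,
    cache_data_per_node=True,  # only LOCAL_RANK 0 writes the cache
    use_cache=True,            # reuse train.txt.pkl on later runs
)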
Example #6
    def __init__(
        self,
        tokenizer: TokenizerSpec,
        input_file: str = None,
        queries: List[str] = None,
        max_seq_length: int = -1,
        num_samples: int = -1,
        shuffle: bool = False,
        use_cache: bool = False,
    ):
        if not input_file and not queries:
            raise ValueError(
                "Either input_file or queries should be passed to the text classification dataset."
            )

        if input_file and not os.path.exists(input_file):
            raise FileNotFoundError(
                f'Data file `{input_file}` not found! Each line of the data file should contain text sequences, where '
                f'words are separated with spaces and the label is separated by a [TAB], following this format: '
                f'[WORD][SPACE][WORD][SPACE][WORD][TAB][LABEL]')

        self.input_file = input_file
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.num_samples = num_samples
        self.shuffle = shuffle
        self.use_cache = use_cache
        self.vocab_size = self.tokenizer.vocab_size
        self.pad_id = tokenizer.pad_id

        self.features = None
        labels, all_sents = [], []
        if input_file:
            data_dir, filename = os.path.split(input_file)
            vocab_size = getattr(tokenizer, "vocab_size", 0)
            tokenizer_name = tokenizer.name
            cached_features_file = os.path.join(
                data_dir,
                f"cached_{filename}_{tokenizer_name}_{max_seq_length}_{vocab_size}_{num_samples}_{self.pad_id}_{shuffle}.pkl",
            )

            if get_envint("LOCAL_RANK", 0) == 0:
                if use_cache and os.path.exists(cached_features_file):
                    logging.warning(
                        f"Processing of {input_file} is skipped as caching is enabled and a cache file "
                        f"{cached_features_file} already exists.")
                    logging.warning(
                        f"You may need to delete the cache file if any of the processing parameters (eg. tokenizer) or "
                        f"the data are updated.")
                else:
                    with open(input_file, "r") as f:
                        # readlines() takes a byte-size hint, not a line count,
                        # so read everything and slice to num_samples ourselves.
                        lines = f.readlines()
                        if num_samples > 0:
                            lines = lines[:num_samples]
                        logging.info(
                            f'Read {len(lines)} examples from {input_file}.')

                        if shuffle:
                            random.shuffle(lines)

                        for index, line in enumerate(lines):
                            if index % 20000 == 0:
                                logging.debug(
                                    f"Processing line {index}/{len(lines)}")
                            line_split = line.strip().split()
                            try:
                                label = int(line_split[-1])
                            except ValueError:
                                logging.debug(f"Skipping line {line}")
                                continue
                            labels.append(label)
                            sent_words = line_split[:-1]
                            all_sents.append(sent_words)
                    verbose = True

                    self.features = self.get_features(
                        all_sents=all_sents,
                        tokenizer=tokenizer,
                        max_seq_length=max_seq_length,
                        labels=labels,
                        verbose=verbose,
                    )
                    with open(cached_features_file, 'wb') as out_file:
                        pickle.dump(self.features,
                                    out_file,
                                    protocol=pickle.HIGHEST_PROTOCOL)
        else:
            for query in queries:
                all_sents.append(query.strip().split())
            labels = [-1] * len(all_sents)
            verbose = False
            self.features = self.get_features(all_sents=all_sents,
                                              tokenizer=tokenizer,
                                              max_seq_length=max_seq_length,
                                              labels=labels,
                                              verbose=verbose)

        # wait until the master process writes to the processed data files
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

        if input_file:
            # Use a fresh name so the input_file argument is not shadowed.
            with open(cached_features_file, "rb") as cache_file:
                self.features = pickle.load(cache_file)
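The constructor follows a rank-0-prepares, everyone-waits pattern that can be distilled to the sketch below (build_cache and load_cache are hypothetical stand-ins for the feature extraction and pickle round-trip above):

import torch

# Only the first process on each node materializes the cache ...
if get_envint("LOCAL_RANK", 0) == 0:
    build_cache()   # hypothetical expensive preprocessing

# ... every other rank blocks until it exists ...
if torch.distributed.is_initialized():
    torch.distributed.barrier()

# ... and then all ranks read the same file.
features = load_cache()  # hypothetical; safe on every rank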