def is_global_rank_zero():
    """
    Helper function to determine if the current process is global_rank 0 (the main process)

    The ``RANK`` environment variable is consulted first. When it is not set,
    the decision falls back to a node rank (derived from SLURM or from the
    ``NODE_RANK`` / ``GROUP_RANK`` variables) combined with ``LOCAL_RANK``.

    Returns:
        bool: True if this process is considered global rank 0.
    """
    # Try to get the pytorch RANK env var
    # RANK is set by torch.distributed.launch
    rank = get_envint("RANK", None)
    # Compare against None explicitly: a plain truthiness test would treat a
    # legitimate RANK of 0 as "not set" and fall through to the fallback path.
    if rank is not None:
        return rank == 0

    # If not set by pytorch, we need to determine node_rank
    def get_node_rank():
        # Use an equivalent of pytorch lightning's determine_ddp_node_rank()
        node_rank = 0
        # First check if running on a slurm cluster
        # TODO: This check could probably be better
        num_slurm_tasks = get_envint("SLURM_NTASKS", 0)
        if num_slurm_tasks > 0:
            node_rank = get_envint("SLURM_NODEID", 0)
        else:
            node_rank_env = get_envint("NODE_RANK", None)
            group_rank = get_envint("GROUP_RANK", None)
            if group_rank is not None:
                node_rank = group_rank
            # Take from NODE_RANK whenever available. The explicit None check
            # lets NODE_RANK=0 override a non-zero GROUP_RANK.
            if node_rank_env is not None:
                node_rank = node_rank_env
        return node_rank

    node_rank = get_node_rank()
    local_rank = get_envint("LOCAL_RANK", 0)
    return node_rank == 0 and local_rank == 0
def _define_logger(self):
    """Create and configure the "nemo_logger" logger if it does not exist yet.

    Stream handlers are removed first, then re-added depending on the
    environment:

    * when the ``NEMO_ENV_VARNAME_TESTING`` flag is enabled, a log-record
      factory is installed that attaches the ``RANK`` env value (default 0)
      to every record as ``record.rank``, and stream handlers using
      ``DebugNeMoFormatter`` are added;
    * otherwise stream handlers are added only when ``RANK`` (default 0) is 0.

    Verbosity is always set to ``Logger.INFO`` and propagation is disabled.

    Returns:
        The existing logger when it was already defined; ``None`` after
        performing the first-time setup.
    """
    # Use double-checked locking to avoid taking lock unnecessarily.
    if self._logger is not None:
        return self._logger

    with self._logger_lock:
        # Second check, now under the lock: another thread may have finished
        # the setup while this one was waiting. Re-running it would wrap the
        # log-record factory a second time and rebuild the handlers.
        if self._logger is not None:
            return self._logger

        try:
            self._logger = _logging.getLogger("nemo_logger")
            # By default, silence all loggers except the logger for rank 0
            self.remove_stream_handlers()
            if get_envbool(NEMO_ENV_VARNAME_TESTING, False):
                old_factory = _logging.getLogRecordFactory()

                def record_factory(*args, **kwargs):
                    # Delegate to the previous factory, then tag the record
                    # with this process's RANK.
                    record = old_factory(*args, **kwargs)
                    record.rank = get_envint("RANK", 0)
                    return record

                _logging.setLogRecordFactory(record_factory)
                self.add_stream_handlers(formatter=DebugNeMoFormatter)
            elif get_envint("RANK", 0) == 0:
                self.add_stream_handlers()

        finally:
            self.set_verbosity(verbosity_level=Logger.INFO)
            self._logger.propagate = False
def get_node_rank():
    """Determine the rank of the current node from environment variables.

    On SLURM (``SLURM_NTASKS`` > 0) the value of ``SLURM_NODEID`` is used.
    Otherwise ``GROUP_RANK`` is used when set, and ``NODE_RANK`` takes
    precedence over it when set. Defaults to 0 when nothing is set.

    Returns:
        int: the node rank.
    """
    # Use an equivalent of pytorch lightning's determine_ddp_node_rank()
    node_rank = 0
    # First check if running on a slurm cluster
    # TODO: This check could probably be better
    num_slurm_tasks = get_envint("SLURM_NTASKS", 0)
    if num_slurm_tasks > 0:
        node_rank = get_envint("SLURM_NODEID", 0)
    else:
        node_rank_env = get_envint("NODE_RANK", None)
        group_rank = get_envint("GROUP_RANK", None)
        # Explicit None checks: a rank of 0 is a valid value and must not be
        # mistaken for "variable not set".
        if group_rank is not None:
            node_rank = group_rank
        # Take from NODE_RANK whenever available
        if node_rank_env is not None:
            node_rank = node_rank_env
    return node_rank
def is_global_rank_zero():
    """
    Tell whether the current process is global rank 0 (the main process).

    Environment variables are consulted in this order:

    1. ``RANK`` (set by torch.distributed.launch)
    2. ``SLURM_PROCID`` (set by SLURM)
    3. ``NODE_RANK`` (falling back to ``GROUP_RANK``) together with ``LOCAL_RANK``

    Any variable that is not defined in step 3 is taken to be 0, so a process
    with none of these variables set is treated as global rank 0.

    Returns:
        bool: True when this process is global rank 0.
    """
    # The first global-rank variable that is defined decides the answer.
    for env_name in ("RANK", "SLURM_PROCID"):
        global_rank = get_envint(env_name, None)
        if global_rank is not None:
            return global_rank == 0

    # Neither the pytorch nor the SLURM variable is set: combine the node-level
    # and process-level ranks, assuming zero for whatever is undefined.
    group_rank = get_envint("GROUP_RANK", 0)
    node_rank = get_envint("NODE_RANK", group_rank)
    local_rank = get_envint("LOCAL_RANK", 0)
    return (node_rank, local_rank) == (0, 0)
def _define_logger(self):
    """Create and configure the "nemo_logger" logger if it does not exist yet.

    Stream handlers are removed first and re-added only when the ``RANK``
    environment variable (default 0) is 0, so other ranks stay silent.
    Verbosity is always set to ``Logger.INFO`` and propagation is disabled.

    Returns:
        The existing logger when it was already defined; ``None`` after
        performing the first-time setup.
    """
    # Use double-checked locking to avoid taking lock unnecessarily.
    if self._logger is not None:
        return self._logger

    with self._logger_lock:
        # Second check, now under the lock: another thread may have finished
        # the setup while this one was waiting, so do not repeat it.
        if self._logger is not None:
            return self._logger

        try:
            self._logger = _logging.getLogger("nemo_logger")
            # By default, silence all loggers except the logger for rank 0
            self.remove_stream_handlers()
            if get_envint("RANK", 0) == 0:
                self.add_stream_handlers()

        finally:
            self.set_verbosity(verbosity_level=Logger.INFO)
            self._logger.propagate = False
def dataset_to_ids(
    dataset, tokenizer, cache_ids=False, add_bos_eos=True, cache_data_per_node=False, use_cache=False,
):
    """
    Reads dataset from file line by line, tokenizes each line with tokenizer,
    and returns list of lists which corresponds to ids of tokenized strings.

    Args:
        dataset (str): path to dataset
        tokenizer: tokenizer to convert text into ids
        cache_ids (bool): if True, ids are saved to disk as pickle file
            with similar name (e.g., data.txt --> data.txt.pkl)
        add_bos_eos (bool): whether to add <s> and </s> symbols (e.g., for NMT)
        cache_data_per_node (bool): Cache data on local_rank 0. Use when there is not a shared-filesystem.
        use_cache (bool): Use cached ids if they exist.
    Returns:
        ids: list of ids which correspond to tokenized strings of the dataset
    """

    cached_ids_dataset = dataset + ".pkl"
    if use_cache and os.path.isfile(cached_ids_dataset):
        logging.info("Loading cached tokenized dataset ...")
        # NOTE: unpickling can execute arbitrary code; the cache file must come
        # from a trusted location.
        with open(cached_ids_dataset, "rb") as cache_file:
            ids = pickle.load(cache_file)
    else:
        logging.info(f"Tokenizing dataset {dataset}...")
        # Read raw bytes and decode per line; the handle is closed before tokenizing.
        with open(dataset, "rb") as data_file:
            data = data_file.readlines()
        ids = []
        for sentence in tqdm(data, desc="Tokenizing sentence"):
            sent_ids = tokenizer.text_to_ids(sentence.decode("utf-8"))
            if add_bos_eos:
                sent_ids = [tokenizer.bos_id] + sent_ids + [tokenizer.eos_id]
            ids.append(sent_ids)
        # Write the cache either when torch.distributed is not initialized, or
        # when per-node caching is requested and this is local rank 0.
        if cache_ids and (
            not torch.distributed.is_initialized() or (cache_data_per_node and get_envint("LOCAL_RANK", 0) == 0)
        ):
            logging.info("Caching tokenized dataset ...")
            # Closing the handle guarantees the pickle is fully flushed to disk.
            with open(cached_ids_dataset, "wb") as cache_file:
                pickle.dump(ids, cache_file)
    return ids
def record_factory(*args, **kwargs):
    """Build a log record with ``old_factory`` and tag it with the process rank.

    All arguments are forwarded unchanged to ``old_factory``. The resulting
    record gets a ``rank`` attribute taken from the ``RANK`` environment
    variable (0 when it is not set).
    """
    tagged_record = old_factory(*args, **kwargs)
    setattr(tagged_record, "rank", get_envint("RANK", 0))
    return tagged_record
def __init__(
    self,
    tokenizer: TokenizerSpec,
    input_file: str = None,
    queries: List[str] = None,
    max_seq_length: int = -1,
    num_samples: int = -1,
    shuffle: bool = False,
    use_cache: bool = False,
):
    """Build the dataset features from a labelled data file or from raw queries.

    When ``input_file`` is given, only the process with ``LOCAL_RANK`` 0
    parses the file, computes the features and pickles them to a cache file
    next to the input; every process then loads the features from that cache
    file. When only ``queries`` is given, features are computed in-process
    with a label of -1 for every query. If both are given, ``input_file``
    takes precedence and ``queries`` is not used.

    Args:
        tokenizer: tokenizer passed to ``get_features``; its ``vocab_size``,
            ``pad_id`` and ``name`` attributes are read here.
        input_file: path to the data file. Each line holds whitespace-separated
            words followed by an integer label as the last token.
        queries: list of raw text queries, used when ``input_file`` is not given.
        max_seq_length: forwarded to ``get_features``.
        num_samples: maximum number of lines to read from ``input_file``;
            a non-positive value reads the whole file.
        shuffle: whether to shuffle the lines read from ``input_file``.
        use_cache: reuse an existing cache file instead of re-processing the data.

    Raises:
        ValueError: if neither ``input_file`` nor ``queries`` is provided.
        FileNotFoundError: if ``input_file`` is provided but does not exist.
    """
    if not input_file and not queries:
        raise ValueError("Either input_file or queries should be passed to the text classification dataset.")

    if input_file and not os.path.exists(input_file):
        raise FileNotFoundError(
            f'Data file `{input_file}` not found! Each line of the data file should contain text sequences, where '
            f'words are separated with spaces and the label separated by [TAB] following this format: '
            f'[WORD][SPACE][WORD][SPACE][WORD][TAB][LABEL]'
        )

    self.input_file = input_file
    self.tokenizer = tokenizer
    self.max_seq_length = max_seq_length
    self.num_samples = num_samples
    self.shuffle = shuffle
    self.use_cache = use_cache
    self.vocab_size = self.tokenizer.vocab_size
    self.pad_id = tokenizer.pad_id

    self.features = None
    labels, all_sents = [], []
    if input_file:
        data_dir, filename = os.path.split(input_file)
        vocab_size = getattr(tokenizer, "vocab_size", 0)
        tokenizer_name = tokenizer.name
        # The cache name encodes every processing parameter so that a change
        # in any of them maps to a different cache file.
        cached_features_file = os.path.join(
            data_dir,
            f"cached_{filename}_{tokenizer_name}_{max_seq_length}_{vocab_size}_{num_samples}_{self.pad_id}_{shuffle}.pkl",
        )

        # Only local rank 0 processes the data; the other ranks read the cache below.
        if get_envint("LOCAL_RANK", 0) == 0:
            if use_cache and os.path.exists(cached_features_file):
                logging.warning(
                    f"Processing of {input_file} is skipped as caching is enabled and a cache file "
                    f"{cached_features_file} already exists."
                )
                logging.warning(
                    f"You may need to delete the cache file if any of the processing parameters (eg. tokenizer) or "
                    f"the data are updated."
                )
            else:
                from itertools import islice

                with open(input_file, "r") as f:
                    # readlines(hint) treats its argument as a byte-size hint,
                    # not a line count, so limit the number of lines explicitly.
                    if num_samples is not None and num_samples > 0:
                        lines = list(islice(f, num_samples))
                    else:
                        lines = f.readlines()
                logging.info(f'Read {len(lines)} examples from {input_file}.')
                if shuffle:
                    random.shuffle(lines)
                for index, line in enumerate(lines):
                    if index % 20000 == 0:
                        logging.debug(f"Processing line {index}/{len(lines)}")
                    line_splited = line.strip().split()
                    try:
                        label = int(line_splited[-1])
                    except (ValueError, IndexError):
                        # ValueError: last token is not an integer label.
                        # IndexError: the line is blank.
                        logging.debug(f"Skipping line {line}")
                        continue
                    labels.append(label)
                    sent_words = line_splited[:-1]
                    all_sents.append(sent_words)
                verbose = True

                self.features = self.get_features(
                    all_sents=all_sents,
                    tokenizer=tokenizer,
                    max_seq_length=max_seq_length,
                    labels=labels,
                    verbose=verbose,
                )
                with open(cached_features_file, 'wb') as out_file:
                    pickle.dump(self.features, out_file, protocol=pickle.HIGHEST_PROTOCOL)
    else:
        for query in queries:
            all_sents.append(query.strip().split())
        labels = [-1] * len(all_sents)
        verbose = False
        self.features = self.get_features(
            all_sents=all_sents, tokenizer=tokenizer, max_seq_length=max_seq_length, labels=labels, verbose=verbose
        )

    # wait until the master process writes to the processed data files
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    if input_file:
        # NOTE: unpickling can execute arbitrary code; the cache file must come
        # from a trusted location.
        with open(cached_features_file, "rb") as cache_file:
            self.features = pickle.load(cache_file)