def __init__(self, model_name, config, vocab_file, model_parallel_size=None, model_parallel_rank=None):
    """Initialize Megatron's global state and build the BERT language model.

    Args:
        model_name: Stored on the instance as ``_model_name``.
        config: Megatron argument defaults, given either as a ``DictConfig``
            or as a plain dict. The caller's object is never modified; the
            constructor works on its own copy.
        vocab_file: Path to the vocabulary file. Must exist on disk.
        model_parallel_size: When not ``None``, the ``WORLD_SIZE`` and ``RANK``
            environment variables are set and the value is forwarded to
            ``self._update_megatron_args`` as ``tensor_model_parallel_size``.
        model_parallel_rank: Written to the ``RANK`` environment variable when
            ``model_parallel_size`` is not ``None``.

    Raises:
        ValueError: If ``vocab_file`` does not exist.
    """
    super().__init__()

    self._model_parallel_size = model_parallel_size
    self._model_parallel_rank = model_parallel_rank
    self._restore_path = None
    self._app_state = None
    self._model_name = model_name

    if not os.path.exists(vocab_file):
        raise ValueError(f'Vocab file not found at {vocab_file}')

    # Work on a private plain-dict copy of the config so the overrides added
    # below never leak back into the object the caller passed in.
    if isinstance(config, DictConfig):
        config = OmegaConf.to_container(config)
    else:
        config = dict(config)

    config["vocab_file"] = vocab_file
    config['tokenizer_type'] = 'BertWordPieceLowerCase'
    config['lazy_mpu_init'] = True
    config['onnx_safe'] = True

    if self._model_parallel_size is not None:
        app_state = AppState()
        self._app_state = app_state

        # must be set for model parallel megatron-lm
        os.environ["WORLD_SIZE"] = str(app_state.world_size)
        # NOTE(review): if model_parallel_rank is None this writes the string
        # 'None' -- confirm callers always pass a rank with model_parallel_size.
        os.environ["RANK"] = str(self._model_parallel_rank)

        extra_args_provider = self._update_megatron_args(tensor_model_parallel_size=self._model_parallel_size)
    else:
        extra_args_provider = self._update_megatron_args()

    # configure globals for megatron
    set_pipeline_model_parallel_rank(0)  # pipeline model parallelism not implemented in NeMo
    set_pipeline_model_parallel_world_size(1)  # pipeline model parallelism not implemented in NeMo

    # Initialize part of Megatron global state that is needed for its constructor.
    # We set the 'lazy_mpu_init' flag on to make Megatron do only the initialization that
    # does not depend on DDP being initialized yet (and we don't want Megatron to initialize
    # DDP itself either) and to return a hook for us to call after PTL has
    # torch.distributed initialized (or, if there is no PTL in case of inference, then
    # we'll initialize torch.distributed).
    # We call and clear this hook on the first call to forward().
    self._lazy_init_fn = initialize_megatron(
        extra_args_provider=extra_args_provider, args_defaults=config, ignore_unknown_args=True
    )

    # read Megatron arguments back
    args = get_args()
    logging.info(f'Megatron-lm argparse args: {args}')

    # get_language_model returns the model together with a key, kept as
    # _language_model_key (presumably the key used for checkpoints -- confirm).
    self.language_model, self._language_model_key = get_language_model(
        attention_mask_func=bert_attention_mask_func, num_tokentypes=2, add_pooler=False
    )

    self.config = OmegaConf.create(config)
    self._hidden_size = self.language_model.hidden_size
def __init__(self, model_name, config, vocab_file, model_parallel_size=None):
    """Initialize Megatron's global state and build the BERT language model.

    Args:
        model_name: Accepted for interface compatibility; not used by this
            constructor.
        config: Mapping of Megatron argument defaults. The caller's object is
            not modified; the constructor works on a shallow copy.
        vocab_file: Path to the vocabulary file. Must exist on disk.
        model_parallel_size: When not ``None``, the ``WORLD_SIZE`` environment
            variable is set and the value becomes the default for Megatron's
            ``model_parallel_size`` argument.

    Raises:
        ValueError: If ``vocab_file`` does not exist.
    """
    super().__init__()

    self._model_parallel_size = model_parallel_size
    self._restore_path = None
    self._app_state = None

    if not os.path.exists(vocab_file):
        raise ValueError(f'Vocab file not found at {vocab_file}')

    # Shallow-copy the config so the overrides added below never leak back
    # into the object the caller passed in.
    config = dict(config)
    config["vocab_file"] = vocab_file
    config['tokenizer_type'] = 'BertWordPieceLowerCase'
    config['lazy_mpu_init'] = True
    config['onnx_safe'] = True

    if self._model_parallel_size is not None:
        app_state = AppState()
        self._app_state = app_state

        # must be set for model parallel megatron-lm
        os.environ["WORLD_SIZE"] = str(app_state.world_size)

        # used to set model_parallel_size in megatron-lm argparser
        def _update_model_parallel_arg(parser):
            parser.set_defaults(model_parallel_size=self._model_parallel_size)
            return parser

        extra_args_provider = _update_model_parallel_arg
    else:
        extra_args_provider = None

    # Initialize part of Megatron global state that is needed for its constructor.
    # We set the 'lazy_mpu_init' flag on to make Megatron do only the initialization that
    # does not depend on DDP being initialized yet (and we don't want Megatron to initialize
    # DDP itself either) and to return a hook for us to call after PTL has
    # torch.distributed initialized.
    # We call this hook during .forward
    # TODO: can we call this hook using the PTL hook .setup()
    self._lazy_init_fn = initialize_megatron(
        extra_args_provider=extra_args_provider, args_defaults=config, ignore_unknown_args=True
    )

    # read Megatron arguments back
    args = get_args()
    logging.info(f'Megatron-lm argparse args: {args}')

    # get_language_model returns the model together with a key, kept as
    # _language_model_key (presumably the key used for checkpoints -- confirm).
    self.language_model, self._language_model_key = get_language_model(
        attention_mask_func=bert_attention_mask_func, num_tokentypes=2, add_pooler=False
    )

    self.config = OmegaConf.create(config)
    self._hidden_size = self.language_model.hidden_size