Example #1
    def __init__(self, model_name, config, vocab_file, model_parallel_size=None, model_parallel_rank=None):

        super().__init__()

        self._model_parallel_size = model_parallel_size
        self._model_parallel_rank = model_parallel_rank
        self._restore_path = None
        self._app_state = None
        self._model_name = model_name

        if not os.path.exists(vocab_file):
            raise ValueError(f'Vocab file not found at {vocab_file}')

        # convert config to dictionary
        if isinstance(config, DictConfig):
            config = OmegaConf.to_container(config)
        config["vocab_file"] = vocab_file
        config['tokenizer_type'] = 'BertWordPieceLowerCase'
        config['lazy_mpu_init'] = True
        config['onnx_safe'] = True

        # if 'model_parallel_size' in config:
        if self._model_parallel_size is not None:
            app_state = AppState()
            self._app_state = app_state

            # these must be set for model-parallel Megatron-LM
            os.environ["WORLD_SIZE"] = str(app_state.world_size)
            os.environ["RANK"] = str(self._model_parallel_rank)

            extra_args_provider = self._update_megatron_args(tensor_model_parallel_size=self._model_parallel_size)

        else:
            extra_args_provider = self._update_megatron_args()

        # configure globals for megatron
        set_pipeline_model_parallel_rank(0)  # pipeline model parallelism not implemented in NeMo
        set_pipeline_model_parallel_world_size(1)  # pipeline model parallelism not implemented in NeMo

        # Initialize the part of Megatron global state that is needed for its constructor.
        # We set the 'lazy_mpu_init' flag so that Megatron performs only the initialization that does not
        # depend on DDP being initialized yet (we don't want Megatron to initialize DDP itself either)
        # and returns a hook for us to call after PTL has initialized torch.distributed
        # (or, when there is no PTL, as in inference, we initialize torch.distributed ourselves).
        # We call and clear this hook on the first call to forward().
        self._lazy_init_fn = initialize_megatron(
            extra_args_provider=extra_args_provider, args_defaults=config, ignore_unknown_args=True
        )

        # read Megatron arguments back
        args = get_args()
        logging.info(f'Megatron-lm argparse args: {args}')

        # _language_model_key is the key used for this submodule's weights in checkpoints
        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func, num_tokentypes=2, add_pooler=False
        )

        self.config = OmegaConf.create(config)
        self._hidden_size = self.language_model.hidden_size
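
A minimal usage sketch for context (the class name `MegatronBertEncoder`, the config keys, and the paths below are assumptions for illustration; only the `__init__` signature above comes from the example):

# Hypothetical usage; class name, config keys, and paths are illustrative only.
from omegaconf import OmegaConf

bert_config = OmegaConf.create(
    {"num_layers": 12, "hidden_size": 768, "num_attention_heads": 12, "max_position_embeddings": 512}
)
encoder = MegatronBertEncoder(           # assumed name of the class defining the __init__ above
    model_name="megatron-bert-uncased",  # assumed model name
    config=bert_config,
    vocab_file="/path/to/vocab.txt",     # must exist, otherwise ValueError is raised
    model_parallel_size=None,            # no model parallelism: the else-branch above is taken
)
print(encoder._hidden_size)              # hidden size read back from the Megatron language model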
Example #2
    def __init__(self, model_name, config, vocab_file, model_parallel_size=None):

        super().__init__()

        self._model_parallel_size = model_parallel_size
        self._restore_path = None
        self._app_state = None

        if not os.path.exists(vocab_file):
            raise ValueError(f'Vocab file not found at {vocab_file}')

        config["vocab_file"] = vocab_file
        config['tokenizer_type'] = 'BertWordPieceLowerCase'
        config['lazy_mpu_init'] = True
        config['onnx_safe'] = True

        # if 'model_parallel_size' in config:
        if self._model_parallel_size is not None:
            app_state = AppState()
            self._app_state = app_state

            # must be set for model-parallel Megatron-LM
            os.environ["WORLD_SIZE"] = str(app_state.world_size)

            # used to set model_parallel_size in the Megatron-LM argument parser
            def _update_model_parallel_arg(parser):
                parser.set_defaults(model_parallel_size=self._model_parallel_size)
                return parser

            extra_args_provider = _update_model_parallel_arg
        else:
            extra_args_provider = None

        # Initialize the part of Megatron global state that is needed for its constructor.
        # We set the 'lazy_mpu_init' flag so that Megatron performs only the initialization that does not
        # depend on DDP being initialized yet (we don't want Megatron to initialize DDP itself either)
        # and returns a hook for us to call after PTL has initialized torch.distributed.
        # We call this hook during .forward().
        # TODO: can we call this hook using the PTL hook .setup()?
        self._lazy_init_fn = initialize_megatron(
            extra_args_provider=extra_args_provider, args_defaults=config, ignore_unknown_args=True
        )

        # read Megatron arguments back
        args = get_args()
        logging.info(f'Megatron-lm argparse args: {args}')

        # _language_model_key is the key used for this submodule's weights in checkpoints
        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func, num_tokentypes=2, add_pooler=False
        )

        self.config = OmegaConf.create(config)
        self._hidden_size = self.language_model.hidden_size
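
The `extra_args_provider` callback in both examples follows Megatron-LM's pattern of handing the caller its argparse parser so that defaults can be overridden before parsing. A self-contained sketch of that pattern (the flag and values are illustrative, not Megatron's actual argument set):

import argparse

def _update_model_parallel_arg(parser):
    # Override the default before parsing, mirroring the provider defined in __init__ above.
    parser.set_defaults(model_parallel_size=2)  # illustrative value
    return parser

parser = argparse.ArgumentParser()
parser.add_argument("--model_parallel_size", type=int, default=1)
parser = _update_model_parallel_arg(parser)
args = parser.parse_args([])
print(args.model_parallel_size)  # -> 2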