Code example #1
File: nlp_overrides.py  Project: NVIDIA/NeMo
 def remove_checkpoint(self, filepath: _PATH) -> None:
     app_state = AppState()
     # PTL override to accommodate model parallel checkpoints
     filepath = inject_model_parallel_rank(filepath)
     if self.is_global_zero or app_state.data_parallel_rank == 0:
         logging.info(f'Removing checkpoint: {filepath}')
         self.checkpoint_io.remove_checkpoint(filepath)
Code example #2
def convert(local_rank, rank, world_size, args):

    app_state = AppState()
    app_state.data_parallel_rank = 0
    num_nodes = world_size // args.gpus_per_node
    if args.bcp:
        trainer = Trainer(devices=args.gpus_per_node,
                          num_nodes=num_nodes,
                          accelerator='gpu',
                          plugins=[TorchElasticEnvironment()])
    else:
        trainer = Trainer(devices=args.gpus_per_node,
                          num_nodes=num_nodes,
                          accelerator='gpu')

    app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size
    app_state.tensor_model_parallel_size = args.tensor_model_parallel_size
    app_state.model_parallel_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
        pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
    )

    app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
    app_state.tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()

    # inject model parallel rank
    checkpoint_path = inject_model_parallel_rank(
        os.path.join(args.checkpoint_folder, args.checkpoint_name))

    logging.info(
        f'rank: {rank}, local_rank: {local_rank}, is loading checkpoint: {checkpoint_path} for tp_rank: {app_state.tensor_model_parallel_rank} and pp_rank: {app_state.pipeline_model_parallel_rank}'
    )

    if args.model_type == 'gpt':
        model = MegatronGPTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 'bert':
        model = MegatronBertModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 't5':
        model = MegatronT5Model.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 'nmt':
        model = MegatronNMTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    else:
        raise ValueError(f"Unrecognized model_type: {args.model_type}")
    model._save_restore_connector = NLPSaveRestoreConnector()

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    model.save_to(args.nemo_file_path)

    logging.info(f'NeMo model saved to: {args.nemo_file_path}')
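convert() above expects local_rank, rank, and world_size to be supplied by its caller. Below is a minimal driver sketch, assuming the RANK, LOCAL_RANK, and WORLD_SIZE environment variables set by torchrun and only the arguments the function actually reads; the real script's entry point and argument list may differ.

import os
from argparse import ArgumentParser

def main():
    parser = ArgumentParser()
    # Only the arguments referenced by convert() are sketched here.
    parser.add_argument('--gpus_per_node', type=int, default=1)
    parser.add_argument('--tensor_model_parallel_size', type=int, default=1)
    parser.add_argument('--pipeline_model_parallel_size', type=int, default=1)
    parser.add_argument('--checkpoint_folder', type=str, required=True)
    parser.add_argument('--checkpoint_name', type=str, required=True)
    parser.add_argument('--hparams_file', type=str, default=None)
    parser.add_argument('--nemo_file_path', type=str, required=True)
    parser.add_argument('--model_type', type=str, default='gpt')
    parser.add_argument('--bcp', action='store_true')
    args = parser.parse_args()

    # Rank information as exported by torchrun / torch.distributed launchers.
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    convert(local_rank, rank, world_size, args)

if __name__ == '__main__':
    main()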
Code example #3
def main(cfg) -> None:

    # trainer required for restoring model parallel models
    trainer = Trainer(plugins=NLPDDPPlugin(), **cfg.trainer)
    assert (
        cfg.trainer.devices * cfg.trainer.num_nodes
        == cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
    ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size"

    app_state = AppState()
    app_state.model_parallel_size = cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
    (
        app_state.tensor_model_parallel_rank,
        app_state.pipeline_model_parallel_rank,
        app_state.model_parallel_size,
        app_state.data_parallel_size,
        app_state.pipeline_model_parallel_split_rank,
    ) = fake_initialize_model_parallel(
        world_size=app_state.model_parallel_size,
        rank=trainer.global_rank,
        tensor_model_parallel_size_=cfg.tensor_model_parallel_size,
        pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size,
        pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank,
    )

    if cfg.model_file is not None:
        if not os.path.exists(cfg.model_file):
            raise ValueError(f"Model file {cfg.model_file} does not exist")
        model = MegatronNMTModel.restore_from(
            restore_path=cfg.model_file, trainer=trainer, save_restore_connector=NLPSaveRestoreConnector(),
        )
    elif cfg.checkpoint_dir is not None:
        checkpoint_path = inject_model_parallel_rank(os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name))
        model = MegatronNMTModel.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer)
    else:
        raise ValueError("need at least a nemo file or checkpoint dir")

    model.freeze()

    logging.info(f"Translating: {cfg.srctext}")
    src_text = []
    translations = []
    with open(cfg.srctext, 'r') as src_f, open(cfg.tgtout, 'w') as tgt_f:
        for line in src_f:
            src_text.append(line.strip())
            if len(src_text) == cfg.batch_size:
                translations = model.translate(
                    text=src_text, source_lang=cfg.source_lang, target_lang=cfg.target_lang,
                )
                for translation in translations:
                    tgt_f.write(translation + "\n")
                src_text = []
        if len(src_text) > 0:
            translations = model.translate(text=src_text, source_lang=cfg.source_lang, target_lang=cfg.target_lang,)
            for translation in translations:
                tgt_f.write(translation + "\n")
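The translation loop above sends cfg.batch_size source lines at a time to model.translate and flushes the remainder at the end. The same pattern can be factored into a small generator; this is a sketch, with batched_lines being a hypothetical helper name and model.translate assumed to accept a list of strings as in the example above.

from typing import Iterable, Iterator, List

def batched_lines(lines: Iterable[str], batch_size: int) -> Iterator[List[str]]:
    # Yield stripped lines in chunks of at most batch_size.
    batch: List[str] = []
    for line in lines:
        batch.append(line.strip())
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

# Usage mirroring the loop above:
# with open(cfg.srctext, 'r') as src_f, open(cfg.tgtout, 'w') as tgt_f:
#     for src_batch in batched_lines(src_f, cfg.batch_size):
#         for translation in model.translate(
#             text=src_batch, source_lang=cfg.source_lang, target_lang=cfg.target_lang,
#         ):
#             tgt_f.write(translation + "\n")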
Code example #4
File: nlp_overrides.py  Project: NVIDIA/NeMo
 def save_checkpoint(self,
                     checkpoint: Dict[str, Any],
                     filepath: _PATH,
                     storage_options: Optional[Any] = None) -> None:
     app_state = AppState()
     # PTL override to accommodate model parallel checkpoints
     filepath = inject_model_parallel_rank(filepath)
     if self.is_global_zero or app_state.data_parallel_rank == 0:
         self.checkpoint_io.save_checkpoint(checkpoint,
                                            filepath,
                                            storage_options=storage_options)
Code example #5
    def _del_model_without_trainer(self, filepath: str) -> None:
        app_state = AppState()
        if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
            # filepath needs to be updated to include mp_rank
            filepath = inject_model_parallel_rank(filepath)

        # each model parallel rank needs to remove its model
        if is_global_rank_zero() or (app_state.model_parallel_size is not None and app_state.data_parallel_rank == 0):
            try:
                self._fs.rm(filepath)
                logging.info(f"Removed checkpoint: {filepath}")
            except Exception:
                logging.info(f"Tried to remove checkpoint: {filepath} but failed.")
Code example #6
def main(cfg) -> None:

    # trainer required for restoring model parallel models
    trainer = Trainer(plugins=NLPDDPPlugin(), **cfg.trainer)
    assert (
        cfg.trainer.devices * cfg.trainer.num_nodes
        == cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
    ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size"

    # Load prompt tuned model, virtual_prompt_model_file must be provided in config
    if cfg.get('virtual_prompt_model_file', None) is not None:

        # Update frozen GPT model path in case it has changed
        prompt_learning_cfg = MegatronGPTPromptLearningModel.restore_from(
            cfg.virtual_prompt_model_file, trainer=trainer, return_config=True)
        with open_dict(prompt_learning_cfg):
            prompt_learning_cfg.language_model_path = cfg.gpt_model_file

        # Now load prompt learning model with frozen gpt model base
        model = MegatronGPTPromptLearningModel.restore_from(
            restore_path=cfg.virtual_prompt_model_file,
            trainer=trainer,
            override_config_path=prompt_learning_cfg)

    # Or load regular GPT model
    elif cfg.gpt_model_file:
        model = MegatronGPTModel.restore_from(restore_path=cfg.gpt_model_file,
                                              trainer=trainer)
    elif cfg.checkpoint_dir:
        app_state = AppState()
        if cfg.tensor_model_parallel_size > 1 or cfg.pipeline_model_parallel_size > 1:
            app_state.model_parallel_size = cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
            (
                app_state.tensor_model_parallel_rank,
                app_state.pipeline_model_parallel_rank,
                app_state.model_parallel_size,
                app_state.data_parallel_size,
                app_state.pipeline_model_parallel_split_rank,
            ) = fake_initialize_model_parallel(
                world_size=app_state.model_parallel_size,
                rank=trainer.global_rank,
                tensor_model_parallel_size_=cfg.tensor_model_parallel_size,
                pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size,
                pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank,
            )
        checkpoint_path = inject_model_parallel_rank(
            os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name))
        model = MegatronGPTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer)
    else:
        raise ValueError("need at least a nemo file or checkpoint dir")

    model.freeze()

    # Have to turn off activations_checkpoint_method for inference
    try:
        model.model.language_model.encoder.activations_checkpoint_method = None
    except AttributeError:
        pass

    try:
        model.frozen_model.language_model.encoder.activations_checkpoint_method = None
    except AttributeError:
        pass

    length_params: LengthParam = {
        "max_length": cfg.inference.tokens_to_generate,
        "min_length": cfg.inference.min_tokens_to_generate,
    }

    sampling_params: SamplingParam = {
        "use_greedy": cfg.inference.greedy,
        "temperature": cfg.inference.temperature,
        "top_k": cfg.inference.top_k,
        "top_p": cfg.inference.top_p,
        "repetition_penalty": cfg.inference.repetition_penalty,
        "add_BOS": cfg.inference.add_BOS,
        "all_probs": cfg.inference.all_probs,
        "compute_logprob": cfg.inference.compute_logprob,
    }

    # First method of running text generation, call model.generate method
    response = model.generate(inputs=OmegaConf.to_container(cfg.prompts),
                              length_params=length_params,
                              sampling_params=sampling_params)

    print("***************************")
    print(response)
    print("***************************")

    # Second method of running text generation, call trainer.predict
    collate_fn = None
    if cfg.get('virtual_prompt_model', False):
        collate_fn = lambda x: list(x)

    ds = RequestDataSet(OmegaConf.to_container(cfg.prompts))
    request_dl = DataLoader(dataset=ds, collate_fn=collate_fn, batch_size=2)

    config = OmegaConf.to_container(cfg.inference)
    model.set_inference_config(config)
    response = trainer.predict(model, request_dl)

    print("***************************")
    print(response)
    print("***************************")

    # Third method of running text generation, use inference server
    if cfg.server:
        if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0:
            server = MegatronServer(model.cuda())
            server.run("0.0.0.0", port=cfg.port)

        while True:
            choice = torch.cuda.LongTensor(1)
            torch.distributed.broadcast(choice, 0)
            if choice[0].item() == 0:
                generate(model.cuda())
Code example #7
 def _inject_model_parallel_rank_for_ckpt(self, dirname, basename):
     model_weights = os.path.join(dirname, basename)
     model_weights = inject_model_parallel_rank(model_weights)
     return model_weights
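For orientation, here is a hedged standalone equivalent of the helper above, using the checkpoint paths that appear in the argparse help text of code example #10 as made-up input; the injected subdirectory depends on the active tensor/pipeline model parallel ranks (see the mp_rank_* pattern in code example #8).

import os

# Assumes inject_model_parallel_rank is imported from NeMo as in the surrounding examples.
# Hypothetical input paths, borrowed from the help strings in code example #10.
dirname = '/raid/nemo_experiments/megatron_gpt/checkpoints'
basename = 'megatron_gpt--val_loss=6.34-step=649-last.ckpt'
model_weights = inject_model_parallel_rank(os.path.join(dirname, basename))
# With tensor_model_parallel_rank=0 and pipeline_model_parallel_size=1 this is expected
# to resolve to .../checkpoints/mp_rank_00/megatron_gpt--val_loss=6.34-step=649-last.ckpt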
Code example #8
def convert(local_rank, rank, world_size, args):

    app_state = AppState()
    app_state.data_parallel_rank = 0
    tensor_model_parallel_size = args.tensor_model_parallel_size
    num_nodes = world_size // args.gpus_per_node
    pipeline_model_parallel_size = world_size // args.tensor_model_parallel_size
    assert args.pipeline_model_parallel_size == pipeline_model_parallel_size

    trainer = Trainer(devices=args.gpus_per_node,
                      accelerator='gpu',
                      num_nodes=num_nodes)

    app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size
    app_state.tensor_model_parallel_size = args.tensor_model_parallel_size
    app_state.model_parallel_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
        pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
    )

    app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
    app_state.tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()

    pipeline_rank = rank // tensor_model_parallel_size
    tensor_rank = app_state.tensor_model_parallel_rank
    assert pipeline_rank == app_state.pipeline_model_parallel_rank

    if tensor_model_parallel_size is not None and tensor_model_parallel_size > 1 and pipeline_model_parallel_size == 1:
        # inject model parallel rank
        checkpoint_path = os.path.join(args.checkpoint_folder,
                                       f'mp_rank_{tensor_rank:02d}',
                                       args.checkpoint_name)
    elif tensor_model_parallel_size is not None and pipeline_model_parallel_size > 1:
        checkpoint_path = os.path.join(
            args.checkpoint_folder,
            f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}',
            args.checkpoint_name)
    else:
        checkpoint_path = os.path.join(args.checkpoint_folder,
                                       args.checkpoint_name)
    logging.info(f"loading checkpoint {checkpoint_path}")

    if args.model_type == 'gpt':
        ## this dictionary is used to rename the model parameters
        name_translate = {}
        name_translate['transformer'] = 'encoder'
        name_translate['.attention.'] = '.self_attention.'
        # nemo megatron doesn't have _for_head key
        name_translate['word_embeddings_for_head'] = 'word_embeddings'
        checkpoint, consumed, steps, version = load_from_checkpoint(
            MegatronGPTModel,
            checkpoint_path,
            hparams_file=args.hparams_file,
            trainer=trainer,
            translator=name_translate,
            strict=False,
        )
    elif args.model_type == 'bert':
        ## this dictionary is used to rename the model parameters
        name_translate = {}
        name_translate['transformer'] = 'encoder'
        name_translate['.attention.'] = '.self_attention.'
        # nemo megatron doesn't have _for_head key
        name_translate['word_embeddings_for_head'] = 'word_embeddings'
        checkpoint, consumed, steps, version = load_from_checkpoint(
            MegatronBertModel,
            checkpoint_path,
            hparams_file=args.hparams_file,
            trainer=trainer,
            translator=name_translate,
            strict=False,
        )
    else:
        raise NotImplementedError("{} is not supported".format(args.model_type))

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    if args.output_ckpt_file_path:
        filepath = args.output_ckpt_file_path
        base_dir = pathlib.Path(filepath).parent
        filename_str = pathlib.Path(filepath).name
        suffix = '.ckpt'
        content = {}
        if consumed is not None:
            content['consumed'] = consumed
        else:
            content['consumed'] = 0
        if steps is not None:
            content['steps'] = steps
        else:
            content['steps'] = 0
        filename = filename_str.format(**content) + suffix
        checkpoint_path_output = inject_model_parallel_rank(
            os.path.join(base_dir, filename))
        trainer.accelerator.training_type_plugin.checkpoint_io.save_checkpoint(
            checkpoint, checkpoint_path_output)
        logging.info(
            f'NeMo model checkpoint files saved to: {args.output_ckpt_file_path}'
        )

    if args.nemo_file_path:
        if args.model_type == 'gpt':
            model = load_model(MegatronGPTModel,
                               checkpoint,
                               strict=False,
                               trainer=trainer)
        elif args.model_type == 'bert':
            model = load_model(MegatronBertModel,
                               checkpoint,
                               strict=False,
                               trainer=trainer)
        else:
            raise NotImplementedError("{} is not supported".format(args.model_type))

        # verify that the tensor parallel rank id and pipeline parallel rank id match
        assert app_state.data_parallel_size == 1
        assert app_state.tensor_model_parallel_size == tensor_model_parallel_size
        assert app_state.tensor_model_parallel_rank == tensor_rank
        assert app_state.pipeline_model_parallel_size == pipeline_model_parallel_size
        assert app_state.pipeline_model_parallel_rank == pipeline_rank
        model._save_restore_connector = NLPSaveRestoreConnector()
        model.save_to(args.nemo_file_path)
        logging.info(f'NeMo model saved to: {args.nemo_file_path}')
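The branches above spell out the checkpoint directory layout that inject_model_parallel_rank otherwise derives from AppState. Below is a minimal sketch of that behaviour, assuming only the mp_rank_{tp:02d} and mp_rank_{tp:02d}_{pp:03d} patterns shown in this example; it is not NeMo's implementation, which may cover additional cases.

import os

def inject_model_parallel_rank_sketch(filepath: str) -> str:
    # Hedged re-implementation based on the directory patterns above; not NeMo's code.
    # Assumes AppState is imported as in the surrounding examples.
    app_state = AppState()
    if app_state.model_parallel_size is None or app_state.model_parallel_size == 1:
        return filepath
    dirname, basename = os.path.split(filepath)
    if app_state.pipeline_model_parallel_size is None or app_state.pipeline_model_parallel_size == 1:
        rank_dir = f'mp_rank_{app_state.tensor_model_parallel_rank:02d}'
    else:
        rank_dir = (f'mp_rank_{app_state.tensor_model_parallel_rank:02d}'
                    f'_{app_state.pipeline_model_parallel_rank:03d}')
    return os.path.join(dirname, rank_dir, basename)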
Code example #9
File: nlp_overrides.py  Project: NVIDIA/NeMo
 def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
     """ PTL override to accomodate model parallel checkpoints """
     # TODO: move to CheckpointIO
     torch.cuda.empty_cache()
     checkpoint_path = inject_model_parallel_rank(checkpoint_path)
     return self.checkpoint_io.load_checkpoint(checkpoint_path)
Code example #10
File: megatron_gpt_eval.py  Project: quuhua911/NeMo
def main():
    parser = ArgumentParser()

    # args for loading the model, either from .nemo file or from PTL checkpoint
    parser.add_argument("--model_file",
                        type=str,
                        default="",
                        required=False,
                        help="Pass path to model's .nemo file")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default=None,
        required=False,
        help=
        "If not using a .nemo file. Path to PTL checkpoints saved during training. Ex: /raid/nemo_experiments/megatron_gpt/checkpoints",
    )
    parser.add_argument(
        "--checkpoint_name",
        type=str,
        default=None,
        required=False,
        help=
        "If not using a .nemo file. Name of checkpoint to be used. Ex: megatron_gpt--val_loss=6.34-step=649-last.ckpt",
    )

    parser.add_argument(
        "--hparams_file",
        type=str,
        default=None,
        required=False,
        help=
        "If not using a .nemo file. Path to config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml",
    )
    parser.add_argument("--tensor_model_parallel_size",
                        type=int,
                        default=1,
                        required=False,
                        help="Needed if not using a .nemo file")
    parser.add_argument(
        "--pipeline_model_parallel_size",
        type=int,
        default=1,
        required=False,
        help="Needed if not using a .nemo file",
    )

    # PTL Trainer args
    parser.add_argument("--devices",
                        default=1,
                        type=int,
                        help="PyTorch Lightning Trainer devices flag")
    parser.add_argument("--num_nodes",
                        default=1,
                        type=int,
                        help="PyTorch Lightning Trainer num_nodes flag")
    parser.add_argument("--precision",
                        default=16,
                        help="PyTorch Lightning Trainer precision flag")

    # evaluation args
    parser.add_argument("--path_to_file",
                        type=str,
                        default="",
                        required=False,
                        help="Path to file with prompts (a text to complete)")
    parser.add_argument("--prompt",
                        type=str,
                        default="",
                        required=False,
                        help="Prompt for the model (a text to complete)")
    parser.add_argument("--use_soft_prompts",
                        action="store_true",
                        help="Use model's existing soft prompts")
    parser.add_argument("--prompt_tag",
                        type=str,
                        default="",
                        required=False,
                        help="Prompt tag string for task specific soft prompt")
    parser.add_argument("--tokens_to_generate",
                        type=int,
                        default="1",
                        required=False,
                        help="How many tokens to add to prompt")
    parser.add_argument(
        "--stop_after_sentence",
        type=bool,  # note: argparse's type=bool treats any non-empty string as True
        default=True,
        required=False,
        help=
        "True/False: whether to stop after full sentence has been generated.",
    )
    parser.add_argument("--batch_size",
                        default=1,
                        type=int,
                        required=False,
                        help="Evaluation batch_size")
    parser.add_argument("--compute_logprobs",
                        type=bool,
                        default=False,
                        required=False,
                        help="Method for logprobs computation")

    args = parser.parse_args()

    assert (
        args.devices * args.num_nodes == args.tensor_model_parallel_size *
        args.pipeline_model_parallel_size
    ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size"

    if args.model_file and args.checkpoint_dir:
        raise ValueError(
            "Only one of model_file or checkpoint_dir should be used")

    # cast precision to int if 32 or 16
    if args.precision in ["32", "16"]:
        args.precision = int(float(args.precision))

    # trainer required for restoring model parallel models
    trainer = Trainer(
        plugins=[NLPDDPPlugin()],
        devices=args.devices,
        num_nodes=args.num_nodes,
        accelerator='gpu',
        precision=args.precision,
    )

    if args.model_file:
        model = MegatronGPTModel.restore_from(restore_path=args.model_file,
                                              trainer=trainer)
    elif args.checkpoint_dir:
        app_state = AppState()
        if args.tensor_model_parallel_size > 1 or args.pipeline_model_parallel_size > 1:
            app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size
            app_state.tensor_model_parallel_size = args.tensor_model_parallel_size
            app_state.model_parallel_size = args.tensor_model_parallel_size * args.pipeline_model_parallel_size
            (
                app_state.tensor_model_parallel_rank,
                app_state.pipeline_model_parallel_rank,
                app_state.model_parallel_size,
                _,
            ) = fake_initialize_model_parallel(
                world_size=app_state.model_parallel_size,
                rank=trainer.global_rank,
                tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
                pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
            )
        # inject model parallel rank
        checkpoint_path = inject_model_parallel_rank(
            os.path.join(args.checkpoint_dir, args.checkpoint_name))

        model = MegatronGPTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    else:
        raise ValueError("Need either a .nemo model file or a checkpoint dir")

    model.freeze()

    def pad_collate(batch):
        tokens = batch[0]['data']
        tokens_to_generate = batch[0]['tokens_to_generate']
        compute_logprobs = batch[0]['compute_logprobs']
        lens = [len(token) for token in tokens]

        tokens_pad = pad_sequence(tokens,
                                  batch_first=False,
                                  padding_value=50256)
        data = []

        if 'prompt_tags' in batch[0]:
            # Keep track of soft prompt tags
            prompt_tags = batch[0]['prompt_tags']

            for token, lenn, prompt_tag in zip(tokens_pad.T, lens,
                                               prompt_tags):
                data.append((token, lenn, tokens_to_generate, compute_logprobs,
                             prompt_tag))
        else:
            for token, lenn in zip(tokens_pad.T, lens):
                data.append(
                    (token, lenn, tokens_to_generate, compute_logprobs))

        return data

    # defining type of request
    if args.path_to_file != "":
        request = []
        # Use a context manager so the prompt file is closed after reading.
        with open(args.path_to_file, 'r', encoding='utf-8') as prompts:
            for prompt in prompts.readlines():
                prompt = prompt.split('\n')[0]

                if args.use_soft_prompts and model.use_soft_prompts:
                    prompt = json.loads(prompt)

                request.append(prompt)

        dataset = GPTRequestDataset(request, model.tokenizer,
                                    args.tokens_to_generate,
                                    args.compute_logprobs)
        request_dl = DataLoader(dataset=pad_collate(dataset),
                                batch_size=int(args.batch_size))

    else:
        if args.use_soft_prompts and model.use_soft_prompts:
            request = [{'prompt_tag': args.prompt_tag, 'text': args.prompt}]
        else:
            request = [args.prompt]

        dataset = GPTRequestDataset(request, model.tokenizer,
                                    args.tokens_to_generate,
                                    args.compute_logprobs)
        request_dl = DataLoader(dataset=pad_collate(dataset), batch_size=1)

    # For GPT models that have had soft prompt tuning but you don't want to use any soft prompts
    if not args.use_soft_prompts and model.use_soft_prompts:
        model.use_soft_prompts = False

    response = trainer.predict(model, request_dl)

    print("***************************")
    print(response)
    print("***************************")
    if args.prompt and not args.compute_logprobs:
        print(f'Prompt: {args.prompt}\n\nResponse: {response[0][0][0]}')
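As a side note, pad_collate above relies on torch.nn.utils.rnn.pad_sequence with batch_first=False, so tokens_pad has shape (max_len, batch) and tokens_pad.T iterates over batch items. A small self-contained illustration follows; the token values are made up, and 50256 matches the padding_value used above.

import torch
from torch.nn.utils.rnn import pad_sequence

tokens = [torch.tensor([11, 12, 13]), torch.tensor([21, 22])]
tokens_pad = pad_sequence(tokens, batch_first=False, padding_value=50256)
print(tokens_pad.shape)   # torch.Size([3, 2]) -> (max_len, batch)
for row in tokens_pad.T:  # one row per batch item, as iterated in pad_collate
    print(row.tolist())   # [11, 12, 13], then [21, 22, 50256]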
Code example #11
File: nlp_overrides.py  Project: quuhua911/NeMo
 def remove_checkpoint(self, filepath: _PATH) -> None:
     # PTL override to accommodate model parallel checkpoints
     filepath = inject_model_parallel_rank(filepath)
     logging.info(f'Removing checkpoint: {filepath}')
     return super().remove_checkpoint(filepath)
Code example #12
File: nlp_overrides.py  Project: quuhua911/NeMo
 def save_checkpoint(self, checkpoint: Dict[str, Any],
                     filepath: _PATH) -> None:
     # PTL override to accommodate model parallel checkpoints
     filepath = inject_model_parallel_rank(filepath)
     return super().save_checkpoint(checkpoint, filepath)