Example #1
def convert(rank, world_size, args):

    app_state = AppState()
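    # pin this conversion process to data parallel rank 0 in NeMo's global AppState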
    app_state.data_parallel_rank = 0
    trainer = Trainer(gpus=args.tensor_model_parallel_size)
    # TODO: reach out to PTL For an API-safe local rank override
    trainer.accelerator.training_type_plugin._local_rank = rank

    if args.tensor_model_parallel_size is not None and args.tensor_model_parallel_size > 1:
        # inject model parallel rank
        checkpoint_path = os.path.join(args.checkpoint_folder,
                                       f'mp_rank_{rank:02d}',
                                       args.checkpoint_name)
    else:
        checkpoint_path = os.path.join(args.checkpoint_folder,
                                       args.checkpoint_name)

    if args.model_type == 'gpt':
        model = MegatronGPTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 'bert':
        model = MegatronBertModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 't5':
        model = MegatronT5Model.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)

    model._save_restore_connector = NLPSaveRestoreConnector()

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    model.save_to(args.nemo_file_path)

    logging.info(f'NeMo model saved to: {args.nemo_file_path}')
Example #2
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    plugins = [NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes)]

    trainer = Trainer(plugins=plugins, **cfg.trainer)
    exp_manager(trainer, cfg.exp_manager)
    model = MegatronGPTModel.restore_from(cfg.restore_from_path,
                                          cfg.model,
                                          trainer=trainer)

    # Init all new prompts
    for idx, tag in enumerate(cfg.model.new_prompt_tags):
        init_method = cfg.model.new_prompt_init_methods[idx]

        if init_method == "text":
            init_text = cfg.model.new_prompt_init_text[idx]
            model.init_prompt_from_text(tag, init_text)

        elif init_method == 'random':
            model.init_prompt_from_random(tag)

        else:
            logging.info(
                f'\n Soft prompt init method {init_method} is not recognized, please use text or random'
            )

    logging.info(f'\nCurrent soft prompts include {model.get_prompt_table()}')
    trainer.fit(model)
Example #3
    def setup_method(self, test_method):
        trainer_config = {
            "devices": 1,
            "num_nodes": 1,
            "accelerator": "gpu",
            "logger": False,
            "precision": 16,
        }
        tensor_model_parallel_size = 1
        pipeline_model_parallel_size = 1
        model_file = '/home/TestData/nlp/megatron_gpt/125M/megatron_gpt.nemo'

        # trainer required for restoring model parallel models
        trainer = Trainer(plugins=NLPDDPPlugin(), **trainer_config)
        assert (
            trainer_config["devices"] *
            trainer_config['num_nodes'] == tensor_model_parallel_size *
            pipeline_model_parallel_size
        ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size"

        model = MegatronGPTModel.restore_from(restore_path=model_file,
                                              trainer=trainer)
        model.freeze()

        # activations_checkpoint_method must be turned off for inference
        try:
            model.model.language_model.encoder.activations_checkpoint_method = None
        except AttributeError:
            pass

        self.model = model
Example #4
    def __init__(
        self,
        cfg: DictConfig,
        trainer: Trainer = None,
    ):

        self.cfg = cfg
        self.data_prepared = False

        self.setup_tokenizer(cfg.tokenizer)
        super().__init__(cfg=cfg, trainer=trainer)

        if self.cfg.library == "huggingface":
            self.language_model = AutoModelWithLMHead.from_pretrained(
                cfg.language_model.pretrained_model_name)
            self.language_model.resize_token_embeddings(
                len(self.tokenizer.tokenizer))
        elif self.cfg.library == "megatron":
            self.language_model = MegatronGPTModel.restore_from(
                cfg.language_model.lm_checkpoint, trainer=trainer)
            # 1 corresponds to intent slot; 0 corresponds to squad
            self.prompt_tags = [1, 0] if 'prompt_table' in dir(
                self.language_model) else []
            if hasattr(self.language_model, 'prompt_table'):
                self.language_model.prompt_tuning_param_freeze_and_optimizer_setup(
                )

            # Init all new prompts
            for idx, tag in enumerate(cfg.new_prompt_tags):
                self.prompt_tags.append(tag)
                init_method = cfg.new_prompt_init_methods[idx]
                if init_method == "text":
                    init_text = cfg.new_prompt_init_text[idx]
                    self.language_model.init_prompt_from_text(tag, init_text)
                elif init_method == 'random':
                    self.language_model.init_prompt_from_random(tag)
                else:
                    raise ValueError(
                        f'\n Soft prompt init method {init_method} is not recognized, please use text or random'
                    )

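        # build a label -> id mapping over the union of labels from the train/val/test splits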
        all_labels = list(
            self._train_dl.dataset.all_possible_labels.union(
                self._validation_dl.dataset.all_possible_labels,
                self._test_dl.dataset.all_possible_labels))
        self.label_to_ids = collections.defaultdict(int)

        for i, label in enumerate(all_labels):
            self.label_to_ids[label] = i

        self.all_existing_labels = set(self.label_to_ids.keys())

        self.token_to_words = {}
        self.classification_report = ClassificationReport(
            num_classes=len(self.label_to_ids) + 1,
            mode='micro',
            label_ids=self.label_to_ids,
            dist_sync_on_step=True)
        self.eval_mode = cfg.eval_mode
        self.cfg = cfg
Example #5
def convert(local_rank, rank, world_size, args):

    app_state = AppState()
    app_state.data_parallel_rank = 0
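    # derive the number of nodes from the total world size and the GPUs per node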
    num_nodes = world_size // args.gpus_per_node
    if args.bcp:
        trainer = Trainer(devices=args.gpus_per_node,
                          num_nodes=num_nodes,
                          accelerator='gpu',
                          plugins=[TorchElasticEnvironment()])
    else:
        trainer = Trainer(devices=args.gpus_per_node,
                          num_nodes=num_nodes,
                          accelerator='gpu')

    app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size
    app_state.tensor_model_parallel_size = args.tensor_model_parallel_size
    app_state.model_parallel_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
        pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
    )

    app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank(
    )
    app_state.tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank(
    )

    # inject model parallel rank
    checkpoint_path = inject_model_parallel_rank(
        os.path.join(args.checkpoint_folder, args.checkpoint_name))

    logging.info(
        f'rank: {rank}, local_rank: {local_rank}, is loading checkpoint: {checkpoint_path} for tp_rank: {app_state.tensor_model_parallel_rank} and pp_rank: {app_state.pipeline_model_parallel_rank}'
    )

    if args.model_type == 'gpt':
        model = MegatronGPTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 'bert':
        model = MegatronBertModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 't5':
        model = MegatronT5Model.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 'nmt':
        model = MegatronNMTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    model._save_restore_connector = NLPSaveRestoreConnector()

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    model.save_to(args.nemo_file_path)

    logging.info(f'NeMo model saved to: {args.nemo_file_path}')
Example #6
    def __init__(self, cfg: DictConfig, trainer: Trainer):
        super().__init__(cfg, trainer)

        self.model = MegatronGPTModel.restore_from(
            self.register_artifact('language_model.nemo_file',
                                   cfg.language_model.get('nemo_file', None)),
            trainer=trainer,
        )

        self.tokenizer = self.model.tokenizer
        self.float_type = self.model.model.language_model.encoder.layers[
            0].dtype

        if not cfg.use_lm_finetune:
            self.model.freeze()

        hidden_size = self.model.cfg.hidden_size

        self.embeddings = self.model.model.language_model.embedding.word_embeddings

        self.template = cfg.prompt_encoder.template

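        # the prompt encoder maps the template slots to virtual token embeddings (LSTM-based, hence the dropout and num_layers settings)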
        self.prompt_encoder = PromptEncoder(
            template=cfg.prompt_encoder.template,
            hidden_size=hidden_size,
            lstm_dropout=cfg.prompt_encoder.dropout,
            num_layers=cfg.prompt_encoder.num_layers,
        )

        self._reduced_loss_buffer = []

        # load prompt encoder
        self.hidden_size = hidden_size
        self.tokenizer.add_special_tokens(
            {'additional_special_tokens': [cfg.pseudo_token]})

        self.pseudo_token_id = self.tokenizer.token_to_id(cfg.pseudo_token)
        self.pad_token_id = self.tokenizer.pad_id if self.tokenizer.pad_id is not None else self.tokenizer.unk_id
        self.spell_length = sum(self.template)
        self.special_tokens = set([
            self.tokenizer.eos_id,
            self.tokenizer.pad_id,
            self.tokenizer.sep_id,
            self.tokenizer.unk_id,
            self.tokenizer.bos_id,
            self.tokenizer.cls_id,
        ])
Example #7
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    plugins = [NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes)]
    if cfg.trainer.precision == 16:
        plugins.append(
            NLPNativeMixedPrecisionPlugin(
                init_scale=cfg.model.get('native_amp_init_scale', 2**32),
                growth_interval=cfg.model.get('native_amp_growth_interval',
                                              1000),
            ))
    elif cfg.trainer.precision == 'bf16':
        plugins.append(NLPNativeBfloat16PrecisionPlugin())
    else:
        plugins.append(NLPPrecisionPlugin())

    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer)

    exp_manager(trainer, cfg.exp_manager)

    # update resume from checkpoint found by exp_manager
    resume_from_checkpoint = trainer.resume_from_checkpoint
    if resume_from_checkpoint is not None:
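        # model parallel checkpoints are saved in per-rank subfolders (mp_rank_XX), so rewrite the path for this rank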
        mp_rank = compute_model_parallel_rank(
            trainer.local_rank, cfg.model.tensor_model_parallel_size)
        resume_from_checkpoint = Path(resume_from_checkpoint)
        resume_from_checkpoint = resume_from_checkpoint.parent.parent.joinpath(
            f'mp_rank_{mp_rank:02d}').joinpath(resume_from_checkpoint.name)
        resume_from_checkpoint = str(resume_from_checkpoint)
        logging.info(
            f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer.checkpoint_connector = CheckpointConnector(
        trainer, resume_from_checkpoint=resume_from_checkpoint)
    # Override timer callback to a stateless one
    for idx, callback in enumerate(trainer.callbacks):
        if isinstance(callback, Timer):
            trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time, )

    model = MegatronGPTModel(cfg.model, trainer)

    trainer.fit(model)
Example #8
    def __init__(
        self,
        cfg: DictConfig,
        trainer: Trainer = None,
    ):

        self.cfg = cfg
        self.data_prepared = False

        self.setup_tokenizer(cfg.tokenizer)
        self.tokenizer.tokenizer.pad_token = self.tokenizer.tokenizer.eos_token
        self.epoch_number = 0
        super().__init__(cfg=cfg, trainer=trainer, no_lm_init=True)

        if self.cfg.library == "huggingface":
            self.language_model = AutoModelWithLMHead.from_pretrained(
                cfg.language_model.pretrained_model_name)
            self.language_model.resize_token_embeddings(
                len(self.tokenizer.tokenizer))
            if self.cfg.language_model.lm_checkpoint:
                self.language_model.load_state_dict(
                    torch.load(self.cfg.language_model.lm_checkpoint))
        elif self.cfg.library == "megatron":
            self.language_model = MegatronGPTModel.restore_from(
                cfg.language_model.lm_checkpoint, trainer=trainer)
            # 1 corresponds to intent slot; 0 corresponds to squad
            self.prompt_tags = [1, 0] if 'prompt_table' in dir(
                self.language_model) else []
            if hasattr(self.language_model, 'prompt_table'):
                self.language_model.prompt_tuning_param_freeze_and_optimizer_setup(
                )

            # Init all new prompts
            for idx, tag in enumerate(cfg.new_prompt_tags):
                self.prompt_tags.append(tag)
                init_method = cfg.new_prompt_init_methods[idx]
                if init_method == "text":
                    init_text = cfg.new_prompt_init_text[idx]
                    self.language_model.init_prompt_from_text(tag, init_text)
                elif init_method == 'random':
                    self.language_model.init_prompt_from_random(tag)
                else:
                    raise ValueError(
                        f'\n Soft prompt init method {init_method} is not recognized, please use text or random'
                    )
Example #9
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes the PTune TextClassifier model."""
        super().__init__(cfg=cfg, trainer=trainer)

        initialize_model_parallel_for_nemo(
            world_size=trainer.world_size,
            global_rank=trainer.global_rank,
            local_rank=trainer.local_rank,
            tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1),
            seed=cfg.get('seed', 1234),
        )

        # shared params for dataset and data loaders
        self.dataset_cfg = cfg.dataset
        # tokenizer needs to get initialized before the super.__init__()
        # as dataloaders and datasets need it to process the data
        self.tokenizer = get_nmt_tokenizer(
            library=cfg.tokenizer.library,
            model_name=cfg.tokenizer.type,
            tokenizer_model=self.register_artifact("tokenizer.model", cfg.tokenizer.model),
            vocab_file=self.register_artifact("tokenizer.vocab_file", cfg.tokenizer.vocab_file),
            merges_file=self.register_artifact("tokenizer.merges_file", cfg.tokenizer.merge_file),
        )

        self.class_weights = None

        self.model = MegatronGPTModel.restore_from(
            self.register_artifact('language_model.nemo_file', cfg.language_model.get('nemo_file', None)),
            trainer=trainer,
        )

        if not cfg.use_lm_finetune:
            self.model.freeze()

        hidden_size = self.model.cfg.hidden_size

        # the classification labels come from the dataset config
        self.classes = cfg.dataset.classes

        self.embeddings = self.model.model.language_model.embedding.word_embeddings

        # set allowed vocab set
        self.vocab = self.tokenizer.tokenizer.get_vocab()

        # make sure classes are part of the vocab
        for k in cfg.dataset.classes:
            if token_wrapper(k) not in self.vocab:
                logging.error(f'class {k} is not part of the vocabulary. Please add it to your vocab')
        self.allowed_vocab_ids = set(self.vocab[token_wrapper(k)] for k in cfg.dataset.classes)

        # map from id to label
        self.allowed_vocab = {}
        self.label_ids = {}
        self.id_to_label = {}
        for i, k in enumerate(cfg.dataset.classes):
            self.allowed_vocab[self.vocab[token_wrapper(k)]] = i
            self.label_ids[k] = i
            self.id_to_label[i] = k

        self.template = cfg.prompt_encoder.template

        self.prompt_encoder = PromptEncoder(
            template=cfg.prompt_encoder.template,
            hidden_size=hidden_size,
            lstm_dropout=cfg.prompt_encoder.dropout,
            num_layers=cfg.prompt_encoder.num_layers,
        )

        # load prompt encoder
        self.hidden_size = hidden_size
        self.tokenizer.add_special_tokens({'additional_special_tokens': [cfg.pseudo_token]})

        self.pseudo_token_id = self.tokenizer.tokenizer.get_vocab()[cfg.pseudo_token]
        self.pad_token_id = (
            self.tokenizer.tokenizer.pad_token_id
            if self.tokenizer.tokenizer.pad_token_id is not None
            else self.tokenizer.tokenizer.unk_token_id
        )
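        # total number of virtual prompt tokens implied by the template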
        self.spell_length = sum(self.template)
Example #10
def main():
    parser = ArgumentParser()
    parser.add_argument("--model_file",
                        type=str,
                        default="",
                        required=True,
                        help="Pass path to model's .nemo file")
    parser.add_argument("--prompt",
                        type=str,
                        default="",
                        required=True,
                        help="Prompt for the model (a text to complete)")
    parser.add_argument("--tokens_to_generate",
                        type=int,
                        default="64",
                        required=False,
                        help="How many tokens to add to prompt")
    parser.add_argument(
        "--stop_after_sentence",
        type=bool,
        default="True",
        required=False,
        help=
        "True/False: whether to stop after full sentence has been generated.",
    )
    parser.add_argument(
        "--tensor_model_parallel_size",
        type=int,
        default=1,
        required=True,
    )
    parser.add_argument("--precision",
                        default=32,
                        help="PyTorch Lightning Trainer precision flag")

    args = parser.parse_args()

    # cast precision to int if 32 or 16
    if args.precision in ["32", "16"]:
        args.precision = int(float(args.precision))

    # trainer required for restoring model parallel models
    trainer = Trainer(plugins=NLPDDPPlugin(),
                      gpus=args.tensor_model_parallel_size,
                      precision=args.precision)

    app_state = AppState()
    if args.tensor_model_parallel_size is not None and args.tensor_model_parallel_size > 1:
        app_state.model_parallel_size = args.tensor_model_parallel_size
        app_state.model_parallel_rank = compute_model_parallel_rank(
            trainer.local_rank, app_state.model_parallel_size)

    model = MegatronGPTModel.restore_from(restore_path=args.model_file,
                                          trainer=trainer)

    model.freeze()

    request = {
        "prompt": args.prompt,
        "tokens_to_generate": args.tokens_to_generate,
        "stop_after_sentence": args.stop_after_sentence,
    }

    dataset = GPTRequestDataset(request, model.tokenizer)

    request_dl = DataLoader(dataset)

    response = trainer.predict(model, request_dl)

    print("***************************")
    print(response[0]['completion']['text'])
    print("***************************")
    logging.info(
        f"Generation stopped because: {response[0]['completion']['stop reason']}"
    )
Example #11
    def __init__(self, cfg: DictConfig, trainer: Trainer):
        super().__init__(cfg, trainer)

        self.cfg = cfg

        # Load pretrained GPT model and tokenizer
        if cfg.get('language_model_path', None):
            self.frozen_model = MegatronGPTModel.restore_from(
                cfg.get('language_model_path'),
                trainer=trainer,
                save_restore_connector=NLPSaveRestoreConnector(),
            )

        # Freeze all GPT model weights for prompt-tuning/p-tuning
        self.frozen_model.freeze()
        self.tokenizer = self.frozen_model.tokenizer
        self.float_type = self.frozen_model.model.language_model.encoder.layers[
            0].dtype
        self.hidden_size = self.frozen_model.cfg.hidden_size
        self.word_embeddings = self.frozen_model.model.language_model.embedding.word_embeddings
        self.existing_tasks = list(self.cfg.get('existing_tasks', []))
        self.new_tasks = list(self.cfg.get('new_tasks', []))

        # Load templates for assigning virtual prompt token positions
        self.load_task_templates(self.cfg.task_templates)

        # Prompt table stores all task embeddings, p-tuning virtual prompts get added to the table after training
        self.prompt_table = PromptTable(
            existing_tasks=self.existing_tasks,
            task_templates=self.task_templates,
            task_id_num_to_name=self.task_id_num_to_name,
            hidden_size=self.hidden_size,
        )
        self._prompt_table_key = VirtualPromptSource.PROMPT_TABLE.value
        self._prompt_encoder_key = VirtualPromptSource.PROMPT_ENCODER.value

        # Prepare pseudo token ids for virtual prompt tokens
        self.pseudo_tokens = get_pseudo_tokens(self.max_virtual_tokens)
        self.tokenizer.add_special_tokens(
            {'additional_special_tokens': self.pseudo_tokens})
        self.pseudo_token_ids = self.tokenizer.tokens_to_ids(
            self.pseudo_tokens)
        self.pseudo_token_ids_start = self.pseudo_token_ids[0]
        self.pad_token_id = self.tokenizer.pad_id if self.tokenizer.pad_id is not None else self.tokenizer.unk_id
        self.virtual_prompt_style = VirtualPromptStyle(
            cfg.virtual_prompt_style)

        # Prompt tuning stores virtual prompts in the prompt table and tunes their weight directly
        if self.virtual_prompt_style in [
                VirtualPromptStyle.PROMPT_TUNING, VirtualPromptStyle.INFERENCE
        ]:
            self.virtual_prompt_source = VirtualPromptSource.PROMPT_TABLE

        # P-Tuning uses an LSTM Encoder to produce virtual token embeddings
        elif self.virtual_prompt_style == VirtualPromptStyle.P_TUNING:
            self.virtual_prompt_source = VirtualPromptSource.PROMPT_ENCODER
        else:
            raise ValueError(
                f"\nvirtual prompt style '{cfg.virtual_prompt_style}' not recognized, please use one of 'prompt-tuning' or 'p-tuning'"
            )

        self._reduced_loss_buffer = []
        self._inference_config = None

        if self.trainer.precision == 32:
            self.autocast_dtype = torch.float
        elif self.trainer.precision == 16:
            self.autocast_dtype = torch.half
        elif self.trainer.precision == 'bf16':
            self.autocast_dtype = torch.bfloat16
        else:
            raise ValueError('precision must be in [32, 16, "bf16"]')
        # use the default PyTorch Lightning gradient clipping from the base model
        self.grad_clip_pl_default = True
        # AMP O2 is not supported
        self.megatron_amp_o2 = False
Example #12
def main():
    parser = ArgumentParser()

    # args for loading the model, either from .nemo file or from PTL checkpoint
    parser.add_argument("--model_file",
                        type=str,
                        default="",
                        required=False,
                        help="Pass path to model's .nemo file")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default=None,
        required=False,
        help=
        "If not using a .nemo file. Path to PTL checkpoints saved during training. Ex: /raid/nemo_experiments/megatron_gpt/checkpoints",
    )
    parser.add_argument(
        "--checkpoint_name",
        type=str,
        default=None,
        required=False,
        help=
        "If not using a .nemo file. Name of checkpoint to be used. Ex: megatron_gpt--val_loss=6.34-step=649-last.ckpt",
    )

    parser.add_argument(
        "--hparams_file",
        type=str,
        default=None,
        required=False,
        help=
        "If not using a .nemo file. Path to config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml",
    )
    parser.add_argument("--tensor_model_parallel_size",
                        type=int,
                        default=1,
                        required=False,
                        help="Needed if not using a .nemo file")
    parser.add_argument(
        "--pipeline_model_parallel_size",
        type=int,
        default=1,
        required=False,
        help="Needed if not using a .nemo file",
    )

    # PTL Trainer args
    parser.add_argument("--devices",
                        default=1,
                        type=int,
                        help="PyTorch Lightning Trainer devices flag")
    parser.add_argument("--num_nodes",
                        default=1,
                        type=int,
                        help="PyTorch Lightning Trainer num_nodes flag")
    parser.add_argument("--precision",
                        default=16,
                        help="PyTorch Lightning Trainer precision flag")

    # evaluation args
    parser.add_argument("--path_to_file",
                        type=str,
                        default="",
                        required=False,
                        help="Path to file with prompts (a text to complete)")
    parser.add_argument("--prompt",
                        type=str,
                        default="",
                        required=False,
                        help="Prompt for the model (a text to complete)")
    parser.add_argument("--use_soft_prompts",
                        action="store_true",
                        help="Use model's existing soft prompts")
    parser.add_argument("--prompt_tag",
                        type=str,
                        default="",
                        required=False,
                        help="Prompt tag string for task specific soft prompt")
    parser.add_argument("--tokens_to_generate",
                        type=int,
                        default="1",
                        required=False,
                        help="How many tokens to add to prompt")
    parser.add_argument(
        "--stop_after_sentence",
        type=bool,
        default="True",
        required=False,
        help=
        "True/False: whether to stop after full sentence has been generated.",
    )
    parser.add_argument("--batch_size",
                        default=1,
                        type=int,
                        required=False,
                        help="Evaluation batch_size")
    parser.add_argument("--compute_logprobs",
                        type=bool,
                        default=False,
                        required=False,
                        help="Method for logprobs computation")

    args = parser.parse_args()

    assert (
        args.devices * args.num_nodes == args.tensor_model_parallel_size *
        args.pipeline_model_parallel_size
    ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size"

    if args.model_file and args.checkpoint_dir:
        raise ValueError(
            "Only one of model_file or checkpoint_dir should be used")

    # cast precision to int if 32 or 16
    if args.precision in ["32", "16"]:
        args.precision = int(float(args.precision))

    # trainer required for restoring model parallel models
    trainer = Trainer(
        plugins=[NLPDDPPlugin()],
        devices=args.devices,
        num_nodes=args.num_nodes,
        accelerator='gpu',
        precision=args.precision,
    )

    if args.model_file:
        model = MegatronGPTModel.restore_from(restore_path=args.model_file,
                                              trainer=trainer)
    elif args.checkpoint_dir:
        app_state = AppState()
        if args.tensor_model_parallel_size > 1 or args.pipeline_model_parallel_size > 1:
            app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size
            app_state.tensor_model_parallel_size = args.tensor_model_parallel_size
            app_state.model_parallel_size = args.tensor_model_parallel_size * args.pipeline_model_parallel_size
            (
                app_state.tensor_model_parallel_rank,
                app_state.pipeline_model_parallel_rank,
                app_state.model_parallel_size,
                _,
            ) = fake_initialize_model_parallel(
                world_size=app_state.model_parallel_size,
                rank=trainer.global_rank,
                tensor_model_parallel_size_=app_state.
                tensor_model_parallel_size,
                pipeline_model_parallel_size_=app_state.
                pipeline_model_parallel_size,
            )
        # inject model parallel rank
        checkpoint_path = inject_model_parallel_rank(
            os.path.join(args.checkpoint_dir, args.checkpoint_name))

        model = MegatronGPTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)

    model.freeze()

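    # collate fn: pad variable-length token sequences to a common length (50256 is the GPT-2 <|endoftext|> id) and attach generation metadata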
    def pad_collate(batch):
        tokens, tokens_to_generate = batch[0]['data'], batch[0][
            'tokens_to_generate']
        compute_logprobs = batch[0]['compute_logprobs']
        lens = [len(token) for token in tokens]

        tokens_pad = pad_sequence(tokens,
                                  batch_first=False,
                                  padding_value=50256)
        data = []

        if 'prompt_tags' in batch[0]:
            # Keep track of soft prompt tags
            prompt_tags = batch[0]['prompt_tags']

            for token, lenn, prompt_tag in zip(tokens_pad.T, lens,
                                               prompt_tags):
                data.append((token, lenn, tokens_to_generate, compute_logprobs,
                             prompt_tag))
        else:
            for token, lenn in zip(tokens_pad.T, lens):
                data.append(
                    (token, lenn, tokens_to_generate, compute_logprobs))

        return data

    # define the type of request
    if args.path_to_file != "":
        request = []

        # read prompts line by line; the context manager closes the file afterwards
        with open(args.path_to_file, 'r', encoding='utf-8') as prompts:
            for prompt in prompts.readlines():
                prompt = prompt.split('\n')[0]

                if args.use_soft_prompts and model.use_soft_prompts:
                    prompt = json.loads(prompt)

                request.append(prompt)

        dataset = GPTRequestDataset(request, model.tokenizer,
                                    args.tokens_to_generate,
                                    args.compute_logprobs)
        request_dl = DataLoader(dataset=pad_collate(dataset),
                                batch_size=int(args.batch_size))

    else:
        if args.use_soft_prompts and model.use_soft_prompts:
            request = [{'prompt_tag': args.prompt_tag, 'text': args.prompt}]
        else:
            request = [args.prompt]

        dataset = GPTRequestDataset(request, model.tokenizer,
                                    args.tokens_to_generate,
                                    args.compute_logprobs)
        request_dl = DataLoader(dataset=pad_collate(dataset), batch_size=1)

    # For GPT models that were soft-prompt tuned but where soft prompts should not be used
    if not args.use_soft_prompts and model.use_soft_prompts:
        model.use_soft_prompts = False

    response = trainer.predict(model, request_dl)

    print("***************************")
    print(response)
    print("***************************")
    if args.prompt and not args.compute_logprobs:
        print(f'Prompt: {args.prompt}\n\nResponse: {response[0][0][0]}')
Example #13
    def __init__(self, cfg: DictConfig, trainer: Trainer):
        super().__init__(cfg, trainer)

        self.cfg = cfg

        # Load pretrained GPT model and tokenizer
        self.model = MegatronGPTModel.restore_from(
            self.register_artifact('language_model_path',
                                   cfg.get('language_model_path', None)),
            trainer=trainer,
            save_restore_connector=NLPSaveRestoreConnector(),
        )

        # Freeze all GPT model weights for prompt-tuning/p-tuning
        if not cfg.lm_finetune:
            self.model.freeze()

        self.tokenizer = self.model.tokenizer
        self.float_type = self.model.model.language_model.encoder.layers[
            0].dtype
        self.hidden_size = self.model.cfg.hidden_size
        self.word_embeddings = self.model.model.language_model.embedding.word_embeddings
        self.existing_tasks = list(self.cfg.get('existing_tasks', []))
        self.new_tasks = list(self.cfg.get('new_tasks', []))

        # Load templates for assigning virtual prompt token positions
        self.load_task_templates(self.cfg.task_templates)

        # Prompt table stores all task embeddings, p-tuning virtual prompts get added to the table after training
        self.prompt_table = PromptTable(
            existing_tasks=self.existing_tasks,
            task_templates=self.task_templates,
            task_id_num_to_name=self.task_id_num_to_name,
            hidden_size=self.hidden_size,
        )

        # Prepare pseudo token ids for virtual prompt tokens
        self.pseudo_token_base = cfg.pseudo_token_base
        self.pseudo_tokens = [
            self.pseudo_token_base + str(i)
            for i in range(self.max_virtual_tokens)
        ]
        self.tokenizer.add_special_tokens(
            {'additional_special_tokens': self.pseudo_tokens})
        self.pseudo_token_ids = self.tokenizer.tokens_to_ids(
            self.pseudo_tokens)
        self.pseudo_token_ids_start = self.pseudo_token_ids[0]
        self.pad_token_id = self.tokenizer.pad_id if self.tokenizer.pad_id is not None else self.tokenizer.unk_id
        self.virtual_prompt_style = cfg.virtual_prompt_style.lower()

        # Prompt tuning stores virtual prompts in the prompt table and tunes their weight directly
        if self.virtual_prompt_style in ['prompt-tuning', 'inference']:
            self.virtual_prompt_source = 'prompt-table'

        # P-Tuning uses an LSTM Encoder to produce virtual token embeddings
        elif self.virtual_prompt_style == 'p-tuning':
            self.virtual_prompt_source = 'prompt-encoder'
        else:
            raise ValueError(
                f"\nvirtual prompt style '{cfg.virtual_prompt_type}' not recognized, please use one of 'prompt-tuning' or 'p-tuning'"
            )

        self._reduced_loss_buffer = []
        self._inference_config = None

        if self.trainer.precision == 32:
            self.autocast_dtype = torch.float
        elif self.trainer.precision == 16:
            self.autocast_dtype = torch.half
        elif self.trainer.precision == 'bf16':
            self.autocast_dtype = torch.bfloat16
        else:
            raise ValueError('precision must be in [32, 16, "bf16"]')
Example #14
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)

    plugins = [
        NLPDDPPlugin(
            no_ddp_communication_hook=True,  # we don't use DDP for async grad allreduce
            gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
            find_unused_parameters=False,
        )
    ]
    if cfg.trainer.precision in [16, 'bf16']:
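        # a gradient scaler is only needed for fp16; bf16 runs without loss scaling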
        scaler = None
        if cfg.trainer.precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2**32),
                growth_interval=cfg.model.get('native_amp_growth_interval',
                                              1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        if megatron_amp_o2:
            plugins.append(
                MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision,
                                            device='cuda',
                                            scaler=scaler))
        else:
            plugins.append(
                PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision,
                                             device='cuda',
                                             scaler=scaler))

    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer)

    exp_manager(trainer, cfg.exp_manager)

    # update resume from checkpoint found by exp_manager
    if cfg.model.resume_from_checkpoint is not None:
        resume_from_checkpoint = cfg.model.resume_from_checkpoint
    else:
        resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path

    logging.info(
        f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(
        trainer, resume_from_checkpoint=resume_from_checkpoint)
    # Override timer callback to a stateless one
    for idx, callback in enumerate(trainer.callbacks):
        if isinstance(callback, Timer):
            trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time, )

    # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    model = MegatronGPTModel(cfg.model, trainer)

    trainer.fit(model)