Example 1
    def on_validation_epoch_end(self):
        app_state = AppState()
        if hasattr(self, "_train_ds"):
            _reconfigure_microbatch_calculator(
                rank=app_state.global_rank,
                rampup_batch_size=None,
                global_batch_size=self.cfg.data.train_ds.global_batch_size,
                micro_batch_size=self.cfg.data.train_ds.micro_batch_size,
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )
        # When running `trainer.validate()`, the training dataset is not available.
        else:
            logging.warning(
                'No training data found, reconfiguring microbatches based on validation batch sizes.'
            )
            _reconfigure_microbatch_calculator(
                rank=app_state.global_rank,
                rampup_batch_size=None,
                global_batch_size=self.cfg.data.validation_ds.global_batch_size,
                micro_batch_size=self.cfg.data.validation_ds.micro_batch_size,
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )

        return super().on_validation_epoch_end()
Example 2
    def inference_step(self, batch, batch_idx, mode, dataloader_idx=0):
        batch_has_lang_information = len(batch[0]) == 7

        micro_batch_size = batch[0]['text_enc'].size(0)
        # This should happen only on the last batch of the dataset.
        if micro_batch_size != self.cfg.data.validation_ds.micro_batch_size:
            app_state = AppState()
            _reconfigure_microbatch_calculator(
                rank=app_state.global_rank,
                rampup_batch_size=None,
                global_batch_size=micro_batch_size
                * parallel_state.get_data_parallel_world_size()
                * get_num_microbatches(),
                micro_batch_size=micro_batch_size,
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )

        # At this point, batch is a list of dictionaries where each dict is a microbatch.
        # After the _process_global_batch call, processed_batch will be a single dictionary containing the global batch.
        # This is required since the parent class expects a single global-batch dictionary.
        processed_batch = self._process_global_batch(batch)

        # Call parent validation step to get the loss.
        # NOTE: There could be extra keys in the processed_batch dictionary such as "langs" for XNLI; these will be ignored in the parent class.
        loss = super().validation_step(processed_batch, batch_idx)

        predicted_token_ids, _ = self.decode(
            tokens_enc=processed_batch['text_enc'], enc_mask=processed_batch['enc_mask'], num_tokens_to_generate=30
        )

        # Use the special ids-to-text function to handle stripping <eos> and special tokens with sentencepiece tokenizers.
        preds_text = self.ids_to_text(predicted_token_ids)
        labels_text = self.ids_to_text(processed_batch['labels'])
        input_text = self.ids_to_text(processed_batch['text_enc'])

        if not batch_has_lang_information:
            categories = [None] * len(preds_text)
        else:
            categories = processed_batch['lang']

        metric = self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx]
        assert len(categories) == len(preds_text) == len(labels_text)
        for pred, label, category in zip(preds_text, labels_text, categories):
            # To compute metrics like pearson or spearman correlation, we need to cast the predicted string and labels to floats.
            pred, label = self.cast_for_metric(
                pred, label, self.val_metric_name if mode == 'validation' else self.test_metric_name
            )
            if batch_has_lang_information:
                _ = metric(pred, label, category)
            else:
                _ = metric(pred, label)

        return {
            'loss': loss,
            'preds': preds_text,
            'labels': labels_text,
            'categories': categories,
            'inputs': input_text,
        }
Example 3
    def eval_step(self, batch, batch_idx, dataloader_idx, data_cfg):
        # Need to squeeze dim 0 for tarred datasets since things are pre-batched and we ask the dataloader for batch size 1.
        batch = [[x.squeeze(dim=0) if x.ndim == 3 else x for x in microbatch]
                 for microbatch in batch]
        batch = self.process_global_batch_for_tarred_datasets(batch)
        if data_cfg.dataset_type in ['tarred', 'text']:
            app_state = AppState()
            _reconfigure_microbatch_calculator(
                rank=app_state.global_rank,
                rampup_batch_size=None,
                global_batch_size=batch['text_enc'].size(0) *
                parallel_state.get_data_parallel_world_size(),
                micro_batch_size=batch['text_enc'].size(0),
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )
        # This returns the averaged loss across data-parallel groups.
        reduced_loss = super().validation_step(batch, batch_idx)
        tokens_enc, labels, enc_mask = batch['text_enc'], batch['labels'], batch['enc_mask']
        predicted_tokens_ids, _ = self.decode(
            tokens_enc,
            enc_mask,
            tokens_enc.size(1) + self._cfg.max_generation_delta,  # Generate up to src-length + max generation delta. TODO: Implement better stopping when everything hits <EOS>.
            tokenizer=self.decoder_tokenizer,
        )
        if self.multilingual:
            source_processor = self.source_processor_list[dataloader_idx]
            target_processor = self.target_processor_list[dataloader_idx]
        else:
            source_processor = self.source_processor
            target_processor = self.target_processor

        # Post-process the translations and inputs to log.
        preds = self.postprocess_outputs(
            outputs=predicted_tokens_ids,
            tokenizer=self.decoder_tokenizer,
            processor=target_processor,
        )
        labels = self.postprocess_outputs(
            outputs=labels,
            tokenizer=self.decoder_tokenizer,
            processor=target_processor,
        )
        encoder_inputs = self.postprocess_outputs(
            outputs=tokens_enc,
            tokenizer=self.encoder_tokenizer,
            processor=source_processor,
        )

        return {
            'inputs': encoder_inputs,
            'translations': preds,
            'ground_truths': labels,
            'loss': reduced_loss,
        }
Example 4
 def on_validation_epoch_start(self):
     app_state = AppState()
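     # Switch the microbatch calculator to the validation batch sizes for the duration of the validation epoch.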
     _reconfigure_microbatch_calculator(
         rank=app_state.global_rank,
         rampup_batch_size=None,
         global_batch_size=self.cfg.data.validation_ds.global_batch_size,
         micro_batch_size=self.cfg.data.validation_ds.micro_batch_size,
         data_parallel_size=parallel_state.get_data_parallel_world_size(),
     )
     return super().on_validation_epoch_start()
Example 5
 def on_validation_epoch_end(self):
     app_state = AppState()
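     # Restore the training batch-size configuration after validation; this is only possible when a training dataset is attached.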
     if hasattr(self, "_train_ds"):
         _reconfigure_microbatch_calculator(
             rank=app_state.global_rank,
             rampup_batch_size=None,
             global_batch_size=self._cfg.train_ds.global_batch_size,
             micro_batch_size=self._cfg.train_ds.micro_batch_size,
             data_parallel_size=parallel_state.get_data_parallel_world_size(),
         )
Example 6
    def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
        for data_parallel_size in range(1, self.world_size + 1):

            expected_global_batch_size = self.GLOBAL_BATCH_SIZE
            expected_micro_batch_size = self.MICRO_BATCH_SIZE
            if rampup_batch_size:
                expected_global_batch_size = rampup_batch_size[0]
                num_consumed_samples = 0
                step_of_global_batch_size = rampup_batch_size[1]
                threshold = rampup_batch_size[2]

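            # Skip configurations that are invalid for this world size: odd data-parallel sizes greater
            # than one, and sizes that do not evenly divide the world size.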
            if data_parallel_size > 1 and data_parallel_size % 2 != 0:
                continue
            if self.world_size % data_parallel_size != 0:
                continue
            with self.subTest(data_parallel_size=data_parallel_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=self.world_size // data_parallel_size,
                    pipeline_model_parallel_size_=1,
                )
                self.assertEqual(data_parallel_size,
                                 parallel_state.get_data_parallel_world_size())

                _reconfigure_microbatch_calculator(
                    self.rank,
                    rampup_batch_size,
                    self.GLOBAL_BATCH_SIZE,
                    self.MICRO_BATCH_SIZE,
                    data_parallel_size,
                )

                self.assertEqual(get_micro_batch_size(),
                                 expected_micro_batch_size)
                self.assertEqual(
                    get_num_microbatches(), expected_global_batch_size /
                    expected_micro_batch_size / data_parallel_size)
                current_global_batch_size = get_current_global_batch_size()
                self.assertEqual(current_global_batch_size,
                                 expected_global_batch_size)

                # Make sure `global_batch_size` equals the final global batch size after
                # a certain number of updates.
                if rampup_batch_size:
                    update_num_microbatches(current_global_batch_size)
                    for i in range(100):
                        current_global_batch_size = get_current_global_batch_size()
                        update_num_microbatches(current_global_batch_size)
                    current_global_batch_size = get_current_global_batch_size()
                    self.assertEqual(get_current_global_batch_size(),
                                     self.GLOBAL_BATCH_SIZE)
                parallel_state.destroy_model_parallel()
Example 7
 def training_step(self, batch, batch_idx):
      # Need to squeeze dim 0 for tarred datasets since things are pre-batched and we ask the dataloader for batch size 1.
     batch = [[x.squeeze(dim=0) if x.ndim == 3 else x for x in microbatch]
              for microbatch in batch]
     batch = self.process_global_batch_for_tarred_datasets(batch)
     app_state = AppState()
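      # Tarred datasets are pre-batched, so derive the global batch size from the actual micro-batch dimension on this rank.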
     _reconfigure_microbatch_calculator(
         rank=app_state.global_rank,
         rampup_batch_size=None,
         global_batch_size=batch['text_enc'].size(0) *
         parallel_state.get_data_parallel_world_size(),
         micro_batch_size=batch['text_enc'].size(0),
         data_parallel_size=parallel_state.get_data_parallel_world_size(),
     )
     return super().training_step(batch, batch_idx)
Example 8
 def training_step(self, batch, batch_idx):
     micro_batch_size = batch[0]['text_enc'].size(0)
     # This should happen only on the last batch of the dataset.
     if micro_batch_size != self.cfg.data.train_ds.micro_batch_size:
         app_state = AppState()
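          # A smaller final batch changes the effective global batch size, so update the calculator to match it.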
         _reconfigure_microbatch_calculator(
             rank=app_state.global_rank,
             rampup_batch_size=None,
             global_batch_size=micro_batch_size *
             parallel_state.get_data_parallel_world_size() *
             get_num_microbatches(),
             micro_batch_size=micro_batch_size,
              data_parallel_size=parallel_state.get_data_parallel_world_size(),
          )
     return super().training_step(batch, batch_idx)
Example 9
 def training_step(self, batch, batch_idx):
     micro_batch_size = batch[0]['text_enc'].size(0)
     # This should happen only on the last batch of the dataset.
     if micro_batch_size != self.cfg.data.train_ds.micro_batch_size:
         app_state = AppState()
         _reconfigure_microbatch_calculator(
             rank=app_state.global_rank,
             rampup_batch_size=None,
             global_batch_size=micro_batch_size
             * parallel_state.get_data_parallel_world_size()
             * get_num_microbatches(),
             micro_batch_size=micro_batch_size,
             data_parallel_size=parallel_state.get_data_parallel_world_size(),
         )
      # At this point, batch is a list of dictionaries where each dict is a microbatch.
      # After the _process_global_batch call, batch will be a single dictionary containing the global batch.
      # This is required since the parent class expects a single global-batch dictionary.
     batch = self._process_global_batch(batch)
     return super().training_step(batch, batch_idx)
Example 10
 def training_step(self, batch, batch_idx):
      # Need to squeeze dim 0 for tarred datasets since things are pre-batched and we ask the dataloader for batch size 1.
     if self._cfg.train_ds.dataset_type in ['tarred', 'text']:
         batch = [[
             x.squeeze(dim=0) if x.ndim == 3 else x for x in microbatch
         ] for microbatch in batch]
         batch = self.process_global_batch_for_tarred_datasets(batch)
      elif (self._cfg.train_ds.dataset_type in ['bin_memmap', 'text_memmap']
            and self._cfg.train_ds.get("sampler", "distributed") == 'distributed'):
          batch = self._process_global_batch_without_megatron_batch_sampler(
              batch, tokenizer=self.encoder_tokenizer)
     if self._cfg.train_ds.dataset_type in ['tarred', 'text']:
         app_state = AppState()
         _reconfigure_microbatch_calculator(
             rank=app_state.global_rank,
             rampup_batch_size=None,
             global_batch_size=batch['text_enc'].size(0) *
             parallel_state.get_data_parallel_world_size(),
             micro_batch_size=batch['text_enc'].size(0),
              data_parallel_size=parallel_state.get_data_parallel_world_size(),
          )
     return super().training_step(batch, batch_idx)
Example 11
def tab_sample_sequence_batch(
    model,
    context_tokens,
    context_lengths,
    attention_mask,
    position_ids,
    tokens_to_generate,
    all_probs=True,
    type_ids=None,
    temperature=None,
):
    app_state = AppState()
    micro_batch_size = context_tokens.shape[0]
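    # Generation runs one micro batch at a time with no data parallelism, so collapse the calculator
    # so that global_batch_size == micro_batch_size and data_parallel_size == 1.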
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=micro_batch_size,
        micro_batch_size=micro_batch_size,
        data_parallel_size=1,
    )
    tokenizer = model.tokenizer
    sizes = tokenizer.code_column.sizes
    tokens_per_row = sum(sizes) + 1
    columns = tokenizer.code_column.columns
    num_columns = len(columns)
    tokenid_range = []
    for i in range(num_columns):
        tokenid_range.extend(tokenizer.code_column.get_range(i))

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        context = context_tokens[:, :context_length]
        # the context may start in the middle of the row,
        # calculate the offset according to the position of '\n' or '<|endoftext|>'
        positions = torch.where(context == tokenizer.eor)[1]
        if len(positions) == 0:
            positions = torch.where(context == tokenizer.eod)[1]
        if len(positions) != 0:
            max_position = positions.max().item()
            # TODO: need to make sure contexts in different batches have the same offset lengths;
            # otherwise, need to calculate offset per batch_id
            offset = (context_length - max_position - 1) % tokens_per_row
        else:
            offset = 0

        eod_id = tokenizer.eos_id

        counter = 0

        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None

        # Generate enough tokens for the longest sequence
        maxlen = tokens_to_generate + context_lengths.max().item()

        if maxlen > model.cfg.encoder_seq_length:
            maxlen = model.cfg.encoder_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            # types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                # not using type2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, :context_length]
            else:
                # Set this to false so the memory is not reallocated.
                set_inference_key_value_memory = False
                tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(
                    batch_size, -1)
                # not using type2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, context_length - 1].view(batch_size, -1)
            # micro_batch_size = 2
            attention_mask_repeat = torch.concat(
                [attention_mask for _ in range(micro_batch_size)])
            setkey_value_array = torch.tensor(
                [set_inference_key_value_memory] * micro_batch_size,
                device=torch.cuda.current_device())
            len_array = torch.tensor([maxlen] * micro_batch_size,
                                     device=torch.cuda.current_device())
            batch = [
                tokens2use, attention_mask_repeat, positions2use,
                setkey_value_array, len_array
            ]
            tensor_shape = [
                tokens2use.shape[1], micro_batch_size, model.cfg.hidden_size
            ]

            output = forward_step(model, batch, tensor_shape)

            if parallel_state.is_pipeline_last_stage():
                output = output[0]['logits'].float()
                output = tensor_parallel.gather_from_tensor_model_parallel_region(
                    output)
                assert output is not None
                output = output.float()
                logits = output[:, -1].view(batch_size, -1).contiguous()
                token_in_row = (counter + offset) % tokens_per_row
                logits = logits.float()
                logits /= temperature
                if token_in_row == tokens_per_row - 1:
                    # line break
                    eor_id = tokenizer.eor
                    eod_id = tokenizer.eos_id
                    min_id = min(eor_id, eod_id)
                    max_id = max(eor_id, eod_id) + 1
                    logits = tab_logits(logits, min_id, max_id)
                else:
                    # limit the range
                    min_id, max_id = tokenid_range[token_in_row]
                    logits = tab_logits(logits, min_id, max_id)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                started = context_lengths <= context_length
                # Clamp the out of vocabulary tokens.
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)

                new_tokens = switch(tokens[:, context_length].view(-1), prev,
                                    started)
                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    output_context = F.log_softmax(
                        output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1:context_length + 1],
                                              2)
                    output_logits = torch.gather(output_context, 2,
                                                 indices).squeeze(2)
                    if all_probs:
                        full_logits = output_context
                else:
                    output_context = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output_context, 2,
                                                     indices).squeeze(2)

                    # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
                    output_logits = torch.cat(
                        [output_logits, new_output_logits], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output_context],
                                                1)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eod_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None

            else:
                if parallel_state.is_pipeline_first_stage():
                    src = parallel_state.get_pipeline_model_parallel_last_rank()
                    group = parallel_state.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None

                done = torch.cuda.ByteTensor([0])
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
Example 12
    def _forward_backward_test_impl(
        self,
        forward_only: bool,
        fwd_bwd_func: FwdStepFunc,
        pipeline_model_parallel_world_size: Optional[int],
        virtual_pipeline_model_parallel_size: Optional[int],
        async_comm: bool = False,
        *,
        default_backend: Optional[str] = None,
        p2p_backend: Optional[str] = None,
    ) -> None:
        if fwd_bwd_func == _forward_backward_pipelining_with_interleaving:
            self.assertIsNotNone(virtual_pipeline_model_parallel_size)
            self.assertGreater(virtual_pipeline_model_parallel_size, 1)
        dtype_options = self.dtypes or [torch.float32, torch.double] + _get_autocast_dtypes()

        for dtype, deallocate_pipeline_outputs in itertools.product(
                dtype_options,
                self.deallocate_options,
        ):
            grad_scaler = torch.cuda.amp.GradScaler(init_scale=4.0) if dtype == torch.half else None

            (
                tensor_model_parallel_world_size,
                data_parallel_size,
                pipeline_model_parallel_world_size,
            ) = _get_default_world_sizes_model_parallel_world_size(pipeline_model_parallel_world_size)

            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
                pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
                virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
                default_backend=default_backend,
                p2p_backend=p2p_backend,
            )
            pp_utils._reconfigure_microbatch_calculator(
                rank=parallel_state.get_tensor_model_parallel_rank(),
                rampup_batch_size=None,
                global_batch_size=self.GLOBAL_BATCH_SIZE,
                micro_batch_size=self.MICRO_BATCH_SIZE,
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )

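            # Each pipeline input covers this data-parallel rank's share of the global batch.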
            global_batch_shape = (
                self.GLOBAL_BATCH_SIZE //
                parallel_state.get_data_parallel_world_size(),
                self.HIDDEN_SIZE,
                self.HIDDEN_SIZE,
            )

            batch = None
            if parallel_state.is_pipeline_first_stage():
                batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(), )

            model = build_model(
                testing_utils.model_provider_func,
                # Wrap with DDP only when data parallelism is actually in use.
                wrap_with_ddp=data_parallel_size > 1,
                virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
                hidden_size=self.HIDDEN_SIZE,
            )

            offset = pipeline_model_parallel_world_size if virtual_pipeline_model_parallel_size is not None else 0
            for idx, model_module in enumerate(model):
                model_module = model_module.to(dtype)
                model_module.apply(get_init_weights_func(idx * offset))

            _param_groups = _get_params_for_weight_decay_optimization(model)
            optimizer = torch.optim.Adam(_param_groups, lr=1e-3)

            pp_utils.update_num_microbatches(0)

            loss = fwd_bwd_func(
                testing_utils.fwd_step_func,
                batch,
                model,
                forward_only=forward_only,
                # `tensor_shape` is the shape of micro batch.
                tensor_shape=(
                    self.MICRO_BATCH_SIZE,
                    self.HIDDEN_SIZE,
                    self.HIDDEN_SIZE,
                ),
                dtype=dtype,
                async_comm=async_comm,
                grad_scaler=grad_scaler,
                deallocate_pipeline_output=deallocate_pipeline_outputs,
            )

            if dtype == torch.double:
                hidden_size = self.HIDDEN_SIZE
                microbatch_size = self.MICRO_BATCH_SIZE
                total_layers = pipeline_model_parallel_world_size
                if virtual_pipeline_model_parallel_size is not None:
                    total_layers *= virtual_pipeline_model_parallel_size
                target_loss, target_model = get_target_loss_and_model(
                    global_batch_shape, hidden_size, total_layers)

                for loss_item in loss:
                    x = loss_item['avg']
                    torch.testing.assert_close(x.item() / microbatch_size,
                                               target_loss.item())

                if not forward_only:
                    for vm_id, model_module in enumerate(model):
                        params = list(model_module.parameters())
                        rank = params[0].get_device()
                        offset = pipeline_model_parallel_world_size
                        param_id = rank // data_parallel_size + vm_id * offset
                        target_params = target_model[param_id]

                        torch.testing.assert_close(params[0].cpu(),
                                                   target_params[0])
                        torch.testing.assert_close(params[1].cpu(),
                                                   target_params[1])
                        torch.testing.assert_close(
                            params[0].grad.cpu() / microbatch_size,
                            target_params[0].grad)
                        torch.testing.assert_close(
                            params[1].grad.cpu() / microbatch_size,
                            target_params[1].grad)

            if not forward_only:
                for m in model:
                    for p in m.parameters():
                        self.assertIsNotNone(p.grad)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            parallel_state.destroy_model_parallel()
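Example 13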
    def complete(self, request: Dict):
        """
            Autoregressively invokes the language model in inference mode.
        Args:
            request: Dictionary with the following fields
                * prompt: a string containing the text the model should complete.
                * tokens_to_generate: how many tokens to generate while doing prompt completion.
        Returns:
            response: A python dictionary with the following fields
                * prompt: original text of the prompt
                * tokenized_prompt: list of (str) tokens from prompt
                * completion: a python dictionary with the following subfields:
                    * tokens: a list of triples (token, token_id, log_prob) comprising completion
                    * text: completion text (as a single string)

        """
        app_state = AppState()

        # The complete method only works with global batch size = micro batch size = data parallel size = 1.
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=1,
            micro_batch_size=1,
            data_parallel_size=1,
        )

        response = {}
        self.freeze()
        # naive greedy slow loop
        # TODO: add option for BeamSearchDecoder

        response['prompt'] = request['prompt'][0]
        response['completion'] = {}
        tokens_enc = request['masked_sample']

        response['masked_input'] = ' '.join(self.tokenizer.ids_to_tokens(tokens_enc[0].cpu().numpy().tolist()))
        enc_mask = tokens_enc != self.tokenizer.pad_id

        predicted_tokens_ids, log_probs = self.decode(tokens_enc, enc_mask, int(request['tokens_to_generate']))
        predicted_tokens_ids = predicted_tokens_ids.cpu().numpy()[0].tolist()
        log_probs = log_probs.cpu().numpy()[0].tolist()
        if self.tokenizer.eos_id in predicted_tokens_ids:
            idx = predicted_tokens_ids.index(self.tokenizer.eos_id)
            predicted_tokens_ids = predicted_tokens_ids[:idx]
        else:
            predicted_tokens_ids = [id for id in predicted_tokens_ids if id != self.tokenizer.pad_id]
        # Legacy sentencepiece detokenization still preserves special tokens which messes up exact string match.
        if hasattr(self.tokenizer, 'special_token_to_id'):
            predicted_tokens_ids = [
                id for id in predicted_tokens_ids if id not in self.tokenizer.special_token_to_id.values()
            ]

        predicted_tokens_dec = self.tokenizer.ids_to_tokens(predicted_tokens_ids)
        response['completion']['text'] = self.tokenizer.tokens_to_text(predicted_tokens_dec)
        response['completion']['tokens'] = list(zip(predicted_tokens_ids, predicted_tokens_dec, log_probs))
        self.unfreeze()
        return response
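A minimal sketch of how the `complete` method above might be invoked. This is an assumption-laden illustration rather than something taken from the examples: `model` is assumed to be an already loaded instance of the class, and the token ids in `masked_sample` are placeholders for a real tokenized masked sample.

    import torch

    # Hypothetical request; `model` is assumed to be a loaded model instance,
    # and the token ids below are placeholders, not real vocabulary entries.
    request = {
        'prompt': ['The capital of France is <mask>.'],
        'masked_sample': torch.tensor([[5, 1023, 88, 3]]),  # shape [1, src_len] of encoder token ids
        'tokens_to_generate': 8,
    }
    response = model.complete(request)
    print(response['completion']['text'])

Example 14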
    def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None):
        app_state = AppState()
        global_batch_per_gpu = tokens_enc.size(0)

        num_micro_batches_before_decode = get_num_microbatches()
        # Reconfigure the microbatch calculator here to set the number of microbatches to 1 while decoding, since it's not clear how to decode with "grad acc".
        # TODO: reconfigure back to how things were before decode?
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
        predicted_tokens_dec = (
            torch.LongTensor([self.tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
        )
        encoder_seq_length = tokens_enc.size(1)
        tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]
        assert predicted_tokens_dec.size(0) == global_batch_per_gpu

        for i in range(num_tokens_to_generate):
            # No microbatches in decoding. Just the global batch.
            decoder_seq_length = predicted_tokens_dec.size(1)
            dec_mask = predicted_tokens_dec != self.tokenizer.pad_id

            if encoder_input is not None:
                batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
            else:
                batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]
            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                output_tensor = forward_backward_pipelining_without_interleaving(
                    forward_step_func=self.get_forward_output_only_func(),
                    batch=batch_for_pipeline,
                    model=self.enc_dec_model,
                    forward_only=True,
                    tensor_shape=tensor_shape,
                    decoder_sequence_length=decoder_seq_length,
                    dtype=self.autocast_dtype,
                )
            else:
                output_tensor = forward_backward_no_pipelining(
                    forward_step_func=self.get_forward_output_only_func(),
                    batch=batch_for_pipeline,
                    model=self.enc_dec_model,
                    forward_only=True,
                    tensor_shape=tensor_shape,
                    decoder_sequence_length=decoder_seq_length,
                    dtype=self.autocast_dtype,
                )

            # get output tensor
            if parallel_state.is_pipeline_last_stage():
                output_tensor = output_tensor[0]['logits']
                output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
                log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
                predicted_tokens_dec = torch.cat(
                    [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
                )
            else:
                log_probs = torch.zeros(
                    (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
                ).cuda()
                predicted_tokens_dec = torch.zeros(
                    (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                    dtype=predicted_tokens_dec.dtype,
                ).cuda()

            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                # Broadcast from the last pipeline stage to all other model-parallel ranks.
                torch.distributed.broadcast(
                    predicted_tokens_dec,
                    parallel_state.get_pipeline_model_parallel_last_rank(),
                    group=parallel_state.get_model_parallel_group(),
                )
                torch.distributed.broadcast(
                    log_probs,
                    parallel_state.get_pipeline_model_parallel_last_rank(),
                    group=parallel_state.get_model_parallel_group(),
                )

        # Reset microbatch calculator to what it was before decoding.
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
        return predicted_tokens_dec, log_probs
Example 15
    def inference_step(self, batch, batch_idx, mode, dataloader_idx=0):
        batch_has_lang_information = len(batch[0]) == 7
        # XNLI batches have language information that needs to be removed before calling the parent validation step.
        if batch_has_lang_information:
            processed_batch = []
            for micro_batch in batch:
                micro_batch = {
                    k: v
                    for k, v in micro_batch.items() if k != 'lang'
                }
                processed_batch.append(micro_batch)
        else:
            processed_batch = batch

        micro_batch_size = processed_batch[0]['text_enc'].size(0)
        # This should happen only on the last batch of the dataset.
        if micro_batch_size != self.cfg.data.validation_ds.micro_batch_size:
            app_state = AppState()
            _reconfigure_microbatch_calculator(
                rank=app_state.global_rank,
                rampup_batch_size=None,
                global_batch_size=micro_batch_size *
                parallel_state.get_data_parallel_world_size() *
                get_num_microbatches(),
                micro_batch_size=micro_batch_size,
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )

        # Call parent validation step to get the loss.
        loss = super().validation_step(processed_batch, batch_idx)

        # Remainder of the code is to run the decoding loop, and compute accuracies.
        if batch_has_lang_information:
            tokens_enc, _, _, labels, enc_mask, _, langs = self.process_global_batch(
                batch)
        else:
            tokens_enc, _, _, labels, enc_mask, _ = self.process_global_batch(
                batch)

        predicted_token_ids, _ = self.decode(tokens_enc=tokens_enc,
                                             enc_mask=enc_mask,
                                             num_tokens_to_generate=30)

        preds_text, labels_text = self.preds_and_labels_to_text(
            predicted_token_ids, labels)

        if not batch_has_lang_information:
            if (mode == 'validation'
                    and hasattr(self.cfg.data.validation_ds, "names")
                    and isinstance(self.cfg.data.validation_ds.names, ListConfig)):
                categories = [self.cfg.data.validation_ds.names[dataloader_idx]] * len(preds_text)
            elif (mode == 'test' and hasattr(self.cfg.data.test_ds, "names")
                  and isinstance(self.cfg.data.test_ds.names, ListConfig)):
                categories = [self.cfg.data.test_ds.names[dataloader_idx]] * len(preds_text)
            else:
                categories = [None] * len(preds_text)
        else:
            categories = langs

        metric = self.val_metric if mode == 'validation' else self.test_metric
        assert len(categories) == len(preds_text) == len(labels_text)
        for pred, label, category in zip(preds_text, labels_text, categories):
            _ = metric(pred, label, category)

        return {
            'loss': loss,
            'preds': preds_text,
            'labels': labels_text,
            'categories': categories
        }
Example 16
def sample_sequence_batch(
    model,
    context_tokens,
    context_lengths,
    task_ids,
    attention_mask,
    position_ids,
    tokens_to_generate,
    all_probs=False,
    type_ids=None,
    temperature=None,
    extra={},
):
    # Importing here to avoid circular import errors
    from nemo.collections.nlp.models.language_modeling import MegatronGPTPromptLearningModel

    app_state = AppState()
    micro_batch_size = context_tokens.shape[0]
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=micro_batch_size,
        micro_batch_size=micro_batch_size,
        data_parallel_size=1,
    )
    tokenizer = model.tokenizer
    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval that passes
        # eos_id as an argument and needs termination when that id is found.
        eod_id = tokenizer.eos_id
        counter = 0

        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None
        all_generated_indices = None  # used to track all generated indices
        # Generate enough tokens for the longest sequence
        maxlen = tokens_to_generate + context_lengths.max().item()

        if maxlen > model.cfg.encoder_seq_length + 1:
            maxlen = model.cfg.encoder_seq_length + 1

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            # types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                # not using type2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, :context_length]
            else:
                # Set this to false so the memory is not reallocated.
                set_inference_key_value_memory = False
                tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(
                    batch_size, -1)
                # not using type2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, context_length - 1].view(batch_size, -1)

            attention_mask_repeat = torch.concat(
                [attention_mask for _ in range(micro_batch_size)])
            setkey_value_array = torch.tensor(
                [set_inference_key_value_memory] * micro_batch_size,
                device=torch.cuda.current_device())
            len_array = torch.tensor([maxlen] * micro_batch_size,
                                     device=torch.cuda.current_device())

            # Only prompt learning models will have a prompt table, and require task ids
            if isinstance(model, MegatronGPTPromptLearningModel):
                batch = [
                    tokens2use, attention_mask_repeat, positions2use, task_ids,
                    setkey_value_array, len_array
                ]
                tensor_shape = [
                    tokens2use.shape[1], micro_batch_size,
                    model.frozen_model.cfg.hidden_size
                ]
            else:
                batch = [
                    tokens2use, attention_mask_repeat, positions2use,
                    setkey_value_array, len_array
                ]
                tensor_shape = [
                    tokens2use.shape[1], micro_batch_size,
                    model.cfg.hidden_size
                ]

            output = forward_step(model, batch, tensor_shape)

            if parallel_state.is_pipeline_last_stage():
                output = output[0]['logits'].float()
                output = tensor_parallel.gather_from_tensor_model_parallel_region(
                    output)
                assert output is not None
                output = output.float()
                logits = output[:, -1].view(batch_size, -1).contiguous()

                # make sure it will generate at least min_length
                min_length = extra.get('min_tokens_to_generate', 0)
                if min_length > 0:
                    within_min_length = (context_length -
                                         context_lengths) < min_length
                    logits[within_min_length, eod_id] = -float('Inf')

                # make sure it won't sample outside the vocab_size range
                logits[:, tokenizer.vocab_size:] = -float('Inf')

                if extra.get('greedy', False):
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= temperature
                    # handle repetition penalty
                    logits = repetition_penalty(
                        logits, extra.get('repetition_penalty', 1.2),
                        all_generated_indices)
                    logits = top_k_logits(logits,
                                          top_k=extra.get('top_k', 0),
                                          top_p=extra.get('top_p', 0.9))
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                started = context_lengths <= context_length

                # Clamp the predicted out of vocabulary tokens
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
                new_tokens = switch(tokens[:, context_length].view(-1), prev,
                                    started)

                # Replace sampled tokens w/ done token if EOD has already been sampled
                new_tokens = switch(new_tokens, eod_id, is_done)

                # Replace special soft prompt token ids with unk token ids
                if isinstance(model, MegatronGPTPromptLearningModel):
                    pseudo_token_ids_start = model.pseudo_token_ids_start
                    new_tokens[(new_tokens >=
                                pseudo_token_ids_start)] = tokenizer.unk_id
                    tokens[:, :context_length][(
                        tokens[:, :context_length] >=
                        pseudo_token_ids_start)] = tokenizer.unk_id

                # Insert either new predicted or next prompt token
                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    output = F.log_softmax(output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1:context_length + 1],
                                              2)
                    output_logits = torch.gather(output, 2, indices).squeeze(2)
                    all_generated_indices = indices[:, :, 0]
                    if all_probs:
                        full_logits = output
                else:
                    output = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output, 2,
                                                     indices).squeeze(2)

                    # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
                    output_logits = torch.cat(
                        [output_logits, new_output_logits], 1)
                    all_generated_indices = torch.cat(
                        [all_generated_indices, indices[:, :, 0]], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output], 1)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eod_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None

            else:
                if parallel_state.is_pipeline_first_stage():
                    src = parallel_state.get_pipeline_model_parallel_last_rank()
                    group = parallel_state.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None

                done = torch.cuda.ByteTensor([0])
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
Example 17
def forward_backward_func_template(
    args,
    name: str,
    forward_backward_func,
    pipeline_model_parallel_size: int,
    forward_only: bool,
    dtype: torch.dtype,
    grad_scaler: Optional[GradScaler],
    deallocate_pipeline_outputs: bool,
    data_parallel_size: int,
) -> None:
    print_separator(
        f"{name}, {dtype}, use grad_scaler: {grad_scaler is not None}, "
        f"deallocate_pipeline_outputs: {deallocate_pipeline_outputs}, "
        f"pipeline parallel size: {pipeline_model_parallel_size}, "
        f"data parallel size: {data_parallel_size}")
    virtual_pipeline_model_parallel_size = 2 if name == "interleaving" else None
    if name == "no_pipelining":
        # note (mkozuki): `forward_backward_no_pipelining` is **NOT** compatible with
        # pipeline_model_parallel_size>1. So use pipeline_model_parallel_size as
        # tensor_model_parallel_size and set pipeline_model_parallel_size to 1.
        parallel_state.initialize_model_parallel(1, 1, None)
        _reconfigure_microbatch_calculator(
            args.rank,
            args.rampup_batch_size,
            args.global_batch_size,
            args.micro_batch_size,
            parallel_state.get_data_parallel_world_size(),
        )
    else:
        # NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is necessary to enable interleaving scheduling
        # In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
        # used ubiquitously, but this test uses a custom model so it's safe to abuse.
        parallel_state.initialize_model_parallel(
            data_parallel_size, pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size)
        _reconfigure_microbatch_calculator(
            args.rank,
            args.rampup_batch_size,
            args.global_batch_size,
            args.micro_batch_size,
            parallel_state.get_data_parallel_world_size(),
        )
        if virtual_pipeline_model_parallel_size is not None:
            # Check the experimental warning message
            get_forward_backward_func(virtual_pipeline_model_parallel_size,
                                      pipeline_model_parallel_size)
    pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()

    model = build_model(
        model_provider_func,
        wrap_with_ddp=True,
        virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
        hidden_size=hidden_size,
    )
    assert isinstance(model, list)
    assert len(model) == (1 if virtual_pipeline_model_parallel_size is None
                          else virtual_pipeline_model_parallel_size)
    _param_groups = _get_params_for_weight_decay_optimization(model)
    torch.optim.Adam(_param_groups, lr=1e-4)

    tensor_shape = [
        batch_size // parallel_state.get_data_parallel_world_size(),
        hidden_size, hidden_size
    ]
    batch = (torch.randn(tensor_shape).cuda(), )
    tensor_shape[0] = micro_batch_size

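    # Refresh the microbatch count for zero consumed samples; this is a no-op unless batch-size ramp-up is configured.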
    update_num_microbatches(0)
    forward_backward_func(
        fwd_step_func,
        batch,
        model,
        forward_only=forward_only,
        tensor_shape=tensor_shape,
        dtype=dtype,
        grad_scaler=grad_scaler,
        deallocate_pipeline_outputs=deallocate_pipeline_outputs,
    )

    if not forward_only:
        for m in model:
            for p in m.parameters():
                if p.grad is None:
                    raise RuntimeError("grad not found")
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
Example 18
    def eval_step(self, batch, batch_idx, dataloader_idx):
        # Need to squeeze dim 0 for tarred datasets since things are pre-batched and we ask the dataloader for batch size 1.
        batch = [[x.squeeze(dim=0) if x.ndim == 3 else x for x in microbatch]
                 for microbatch in batch]
        batch = self.process_global_batch_for_tarred_datasets(batch)
        app_state = AppState()
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=batch['text_enc'].size(0) *
            parallel_state.get_data_parallel_world_size(),
            micro_batch_size=batch['text_enc'].size(0),
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
        # This returns the averaged loss across data-parallel groups.
        reduced_loss = super().validation_step(batch, batch_idx)
        tokens_enc, labels, enc_mask = batch['text_enc'], batch['labels'], batch['enc_mask']
        predicted_tokens_ids, _ = self.decode(
            tokens_enc,
            enc_mask,
            tokens_enc.size(1) + self._cfg.max_generation_delta,  # Generate up to src-length + max generation delta. TODO: Implement better stopping when everything hits <EOS>.
            tokenizer=self.decoder_tokenizer,
        )

        # Post-process the translations and inputs to log.

        # Convert ids to lists.
        preds = predicted_tokens_ids.cpu().numpy().tolist()
        labels = labels.cpu().numpy().tolist()
        encoder_inputs = tokens_enc.cpu().numpy().tolist()

        # Filter out the special tokens and de-tokenize.
        inputs = []
        translations = []
        ground_truths = []
        for pred, label, input in zip(preds, labels, encoder_inputs):
            if self.decoder_tokenizer.eos_id in pred:
                idx = pred.index(self.decoder_tokenizer.eos_id)
                pred = pred[:idx]

            # Legacy sentencepiece detokenization still preserves special tokens which messes up exact string match.
            if hasattr(self.decoder_tokenizer, 'special_token_to_id'):
                pred = [
                    id for id in pred if id not in
                    self.decoder_tokenizer.special_token_to_id.values()
                ]
                label = [
                    id for id in label if id not in
                    self.decoder_tokenizer.special_token_to_id.values()
                ]

            if hasattr(self.encoder_tokenizer, 'special_token_to_id'):
                input = [
                    id for id in input if id not in
                    self.encoder_tokenizer.special_token_to_id.values()
                ]

            pred = self.decoder_tokenizer.ids_to_text(pred)
            label = self.decoder_tokenizer.ids_to_text(label)
            input = self.encoder_tokenizer.ids_to_text(input)
            translations.append(pred)
            ground_truths.append(label)
            inputs.append(input)

        if self.multilingual:
            self.source_processor = self.source_processor_list[dataloader_idx]
            self.target_processor = self.target_processor_list[dataloader_idx]

        # De-tokenize inputs, translations and ground truths.
        if self.target_processor is not None:
            ground_truths = [
                self.target_processor.detokenize(tgt.split(' '))
                for tgt in ground_truths
            ]
            translations = [
                self.target_processor.detokenize(translation.split(' '))
                for translation in translations
            ]

        if self.source_processor is not None:
            inputs = [
                self.source_processor.detokenize(src.split(' '))
                for src in inputs
            ]

        return {
            'inputs': inputs,
            'translations': translations,
            'ground_truths': ground_truths,
            'loss': reduced_loss,
        }
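
The post-processing above, truncating each prediction at the first <eos> and stripping special tokens before detokenization, can be factored into a small helper. A minimal sketch under the same assumptions as the example (a tokenizer exposing eos_id, an optional special_token_to_id mapping, and ids_to_text); the helper name is illustrative and not part of the example's API:

def ids_to_clean_text(token_ids, tokenizer):
    # Stop at the first <eos>, if one was generated.
    if tokenizer.eos_id in token_ids:
        token_ids = token_ids[:token_ids.index(tokenizer.eos_id)]
    # Legacy sentencepiece tokenizers keep special tokens in the vocabulary; drop them
    # so they do not pollute exact-string metrics on the detokenized text.
    if hasattr(tokenizer, 'special_token_to_id'):
        special_ids = set(tokenizer.special_token_to_id.values())
        token_ids = [i for i in token_ids if i not in special_ids]
    return tokenizer.ids_to_text(token_ids)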
Example n. 19
0
    def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None, tokenizer=None):
        # Check whether the DDP is initialized. This is needed when running inference outside of training loop.
        if parallel_state.is_unitialized():

            def dummy():
                return

            if self.trainer.strategy.launcher is not None:
                self.trainer.strategy.launcher.launch(dummy, trainer=self.trainer)
            self.trainer.strategy.setup_environment()

            # Reconfigure microbatch sizes here because on model restore, this will contain the micro/global batch configuration used while training.
            _reconfigure_microbatch_calculator(
                rank=0,  # This doesn't matter since it is only used for logging
                rampup_batch_size=None,
                global_batch_size=1,
                micro_batch_size=1,  # Make sure that there is no "grad acc" while decoding.
                data_parallel_size=1,  # We check above to make sure that the data parallel size is always 1 at inference.
            )

        # Classes that inherit from this class may use a different tokenizer; fall back to self.tokenizer if none is passed in.
        tokenizer = self.tokenizer if tokenizer is None else tokenizer
        app_state = AppState()
        global_batch_per_gpu = tokens_enc.size(0)

        num_micro_batches_before_decode = get_num_microbatches()
        # Reconfigure the microbatch calculator here to set the number of microbatches to 1 while decoding, since it's not clear how to decode with "grad acc".
        # TODO: reconfigure back to how things were before decode?
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
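        # Seed every sequence in the batch with the tokenizer's <bos> id; greedily predicted tokens are appended one at a time in the loop below.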
        predicted_tokens_dec = (
            torch.LongTensor([tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
        )
        encoder_seq_length = tokens_enc.size(1)
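        # Shape of the activations passed between pipeline stages: [encoder sequence length, per-GPU batch, hidden size].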
        tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]
        assert predicted_tokens_dec.size(0) == global_batch_per_gpu

        for i in range(num_tokens_to_generate):
            # No microbatches in decoding. Just the global batch.
            decoder_seq_length = predicted_tokens_dec.size(1)
            dec_mask = predicted_tokens_dec != tokenizer.pad_id

            if encoder_input is not None:
                batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
            else:
                batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]
            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                output_tensor = forward_backward_pipelining_without_interleaving(
                    forward_step_func=self.get_forward_output_only_func(),
                    batch=batch_for_pipeline,
                    model=self.enc_dec_model,
                    forward_only=True,
                    tensor_shape=tensor_shape,
                    decoder_sequence_length=decoder_seq_length,
                    dtype=self.autocast_dtype,
                )
            else:
                output_tensor = forward_backward_no_pipelining(
                    forward_step_func=self.get_forward_output_only_func(),
                    batch=batch_for_pipeline,
                    model=self.enc_dec_model,
                    forward_only=True,
                    tensor_shape=tensor_shape,
                    decoder_sequence_length=decoder_seq_length,
                    dtype=self.autocast_dtype,
                )

            # get output tensor
            if parallel_state.is_pipeline_last_stage():
                output_tensor = output_tensor[0]['logits']
                output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
                log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
                predicted_tokens_dec = torch.cat(
                    [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
                )
            else:
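                # Non-last pipeline stages only allocate placeholders here; the broadcast below overwrites them with the real tokens and log-probs from the last stage.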
                log_probs = torch.zeros(
                    (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
                ).cuda()
                predicted_tokens_dec = torch.zeros(
                    (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                    dtype=predicted_tokens_dec.dtype,
                ).cuda()

            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                # Broadcast from the last pipeline stage to all other model-parallel ranks.
                torch.distributed.broadcast(
                    predicted_tokens_dec,
                    parallel_state.get_pipeline_model_parallel_last_rank(),
                    group=parallel_state.get_model_parallel_group(),
                )
                torch.distributed.broadcast(
                    log_probs,
                    parallel_state.get_pipeline_model_parallel_last_rank(),
                    group=parallel_state.get_model_parallel_group(),
                )

        # Reset microbatch calculator to what it was before decoding.
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
        return predicted_tokens_dec, log_probs
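
The decode method brackets generation with two _reconfigure_microbatch_calculator calls: it collapses the per-GPU batch into a single microbatch for the forward-only passes (no gradient accumulation), then restores the split that was in effect before decoding from num_micro_batches_before_decode. A minimal sketch of that bracketing as a context manager, assuming the same _reconfigure_microbatch_calculator and get_num_microbatches helpers used in the examples; the name no_grad_acc_microbatches is illustrative only:

from contextlib import contextmanager

@contextmanager
def no_grad_acc_microbatches(global_rank, batch_per_gpu, data_parallel_size):
    # Remember how many microbatches were configured before decoding.
    num_micro_batches_before = get_num_microbatches()
    # One microbatch per GPU: no gradient accumulation while generating.
    _reconfigure_microbatch_calculator(
        rank=global_rank,
        rampup_batch_size=None,
        global_batch_size=batch_per_gpu * data_parallel_size,
        micro_batch_size=batch_per_gpu,
        data_parallel_size=data_parallel_size,
    )
    try:
        yield
    finally:
        # Restore the previous microbatch split.
        _reconfigure_microbatch_calculator(
            rank=global_rank,
            rampup_batch_size=None,
            global_batch_size=batch_per_gpu * data_parallel_size,
            micro_batch_size=batch_per_gpu // num_micro_batches_before,
            data_parallel_size=data_parallel_size,
        )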
Example n. 20
0
def run_interleaved_with_dynamic_batch_size(
    pipeline_model_parallel_size: int,
    forward_only: bool,
    BatchSamplerCls,
) -> None:
    args = global_vars.get_args()
    _reconfigure_microbatch_calculator(
        args.rank,
        args.rampup_batch_size,
        args.global_batch_size,
        args.micro_batch_size,
        1,  # args.data_parallel_size,
    )
    virtual_pipeline_model_parallel_size = 2
    # NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is required for the interleaving schedule.
    # In Megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
    # used ubiquitously, but this test uses a custom model, so it's safe to set it directly here.
    parallel_state.initialize_model_parallel(
        1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
    pipeline_model_parallel_size = (
        parallel_state.get_pipeline_model_parallel_world_size())

    print_separator(
        f"BatchSamplerCls: {BatchSamplerCls.__name__}, forward_only: {forward_only}"
    )

    model = build_model(
        model_provider_func,
        wrap_with_ddp=True,
        virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
        hidden_size=HIDDEN_SIZE,
    )
    assert isinstance(model, list)
    assert len(model) == virtual_pipeline_model_parallel_size
    optimizer = torch.optim.Adam(
        _get_params_for_weight_decay_optimization(model))

    initial_local_minibatch_size = get_num_microbatches() * micro_batch_size
    dataset = Dataset(NUM_SAMPLES)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=BatchSamplerCls(
            NUM_SAMPLES,
            0,
            initial_local_minibatch_size,
            parallel_state.get_data_parallel_rank(),
            parallel_state.get_data_parallel_world_size(),
        ),
    )
    data_iter = iter(data_loader)

    def get_num_samples(batch):
        if isinstance(batch, torch.Tensor):
            return len(batch)
        assert isinstance(batch, (list, tuple))
        return [get_num_samples(b) for b in batch]

    tensor_shape = [micro_batch_size, HIDDEN_SIZE, HIDDEN_SIZE]
    consumed_samples = 0
    for i in range(NUM_ITERATIONS):
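        # Re-derive the number of microbatches from the ramp-up schedule, then resize the sampler's local minibatch to match.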
        update_num_microbatches(consumed_samples, consistency_check=False)
        local_batch_size = get_num_microbatches() * micro_batch_size
        data_iter._index_sampler.local_minibatch_size = local_batch_size
        local_mini_batch = next(data_iter)

        _logger.info(f"iter: {i} / {NUM_ITERATIONS} "
                     f"local batchsize: {get_num_samples(local_mini_batch)} "
                     f"consumed_samples: {consumed_samples} / {NUM_SAMPLES}")
        _forward_backward_pipelining_with_interleaving(
            fwd_step_func,
            local_mini_batch,
            model,
            forward_only=forward_only,
            tensor_shape=tensor_shape,
        )

        consumed_samples += (parallel_state.get_data_parallel_world_size() *
                             get_num_microbatches() * micro_batch_size)

        if not forward_only:
            for m in model:
                for p in m.parameters():
                    if p.grad is None:
                        raise RuntimeError("grad not found")
            # Clear gradients for the next iteration.
            optimizer.zero_grad(set_to_none=True)

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
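
Both tests finish with the same idiom: synchronize all ranks with a barrier, then let rank 0 alone print the success marker so it appears exactly once. A minimal standalone sketch, assuming torch.distributed has already been initialized; the default message string here is a placeholder, the real one comes from the test harness:

import torch

def finish_distributed_test(message: str = "TEST PASSED") -> None:
    # Wait for every rank so the marker is not printed before slower ranks finish their checks.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(message)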