def allreduce_word_and_position_embeddings(self):

        # Modified from megatron-lm: https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/training.py#L407
        # All-reduce word_embeddings' grad across first and last stages to ensure
        # that word_embeddings parameters stay in sync.
        # This should only run for models that support pipelined model parallelism
        # (BERT and GPT-2).
        if parallel_state.get_pipeline_model_parallel_world_size() > 1 and (
            parallel_state.is_rank_in_embedding_group()
        ):
            if self.enc_dec_model.share_word_embeddings:
                word_embeddings_weight = self.enc_dec_model.word_embeddings_weight()
                if self.megatron_amp_o2:
                    # O2 recipe stores a "main" copy of weights and grads
                    grad = word_embeddings_weight.main_grad
                else:
                    grad = word_embeddings_weight.grad
                torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())

                # All reduce position embeddings for T5.
                if (
                    parallel_state.is_rank_in_position_embedding_group()
                    and parallel_state.get_pipeline_model_parallel_world_size() > 1
                    and parallel_state.get_pipeline_model_parallel_split_rank() is not None
                ):
                    position_embeddings_weight = self.enc_dec_model.position_embeddings_weight()
                    if self.megatron_amp_o2:
                        grad = position_embeddings_weight.main_grad
                    else:
                        grad = position_embeddings_weight.grad
                    torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())
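
# The snippet below is a minimal, self-contained sketch (not NeMo/Megatron code) of
# the gradient-sync pattern used above: two processes stand in for the first and
# last pipeline stages, each holds its own copy of an embedding weight, and a SUM
# all-reduce over the gradients guarantees both copies receive the same update.
# The gloo backend, the port, and the world size of 2 are illustrative assumptions.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _demo_embedding_grad_sync(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Both ranks start from identical weights (mirroring the initial weight sync).
    weight = torch.nn.Parameter(torch.ones(4, 2))
    # Each rank computes a different local gradient for its copy.
    weight.grad = torch.full_like(weight, float(rank + 1))

    # All-reduce (SUM) the gradients so the applied update is identical everywhere.
    dist.all_reduce(weight.grad)
    assert torch.equal(weight.grad, torch.full_like(weight, 3.0))

    dist.destroy_process_group()

# Usage (hypothetical): mp.spawn(_demo_embedding_grad_sync, args=(2,), nprocs=2)
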
    def initialize_word_embeddings(self,
                                   init_method,
                                   vocab_size,
                                   hidden_size,
                                   pipeline_model_parallel_size=1):
        if not self.share_word_embeddings:
            raise Exception('initialize_word_embeddings() was called but '
                            'share_word_embeddings is false')

        # TODO: pipeline model parallelism is not implemented in NeMo yet
        # This function just initializes the word embeddings in the final stage
        # when we are using pipeline parallelism. If we aren't using pipeline
        # parallelism there is nothing to do.
        if pipeline_model_parallel_size == 1:
            return

        # Parameters are shared between the word embeddings layer, and the
        # heads at the end of the model. In a pipelined setup with more than
        # one stage, the initial embedding layer and the head are on different
        # workers, so we do the following:
        # 1. Create a second copy of word_embeddings on the last stage, with
        #    initial parameters of 0.0.
        # 2. Do an all-reduce between the first and last stage to ensure that
        #    the two copies of word_embeddings start off with the same
        #    parameter values.
        # 3. In the training loop, all-reduce the grads of the two
        #    word_embeddings layers before each optimizer step to ensure that
        #    every applied weight update is the same on both stages.
        if parallel_state.is_pipeline_last_stage():
            assert not parallel_state.is_pipeline_first_stage()
            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
            # set word_embeddings weights to 0 here, then copy first
            # stage's weights using all_reduce below.
            self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
                vocab_size, hidden_size, init_method=init_method)
            self.word_embeddings.weight.data.fill_(0)
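            # Starting from zeros matters: the first stage holds weights W and this
            # copy holds 0, so the SUM all-reduce below leaves exactly W on both
            # stages (0 + W == W).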
            self.word_embeddings.weight.shared = True

        # Ensure that first and last stages have the same initial parameter
        # values.
        if torch.distributed.is_initialized():
            if (parallel_state.is_pipeline_first_stage()
                    or parallel_state.is_pipeline_last_stage()):
                torch.distributed.all_reduce(
                    self.word_embeddings_weight().data,
                    group=parallel_state.get_embedding_group())
        else:
            print("WARNING! Distributed processes aren't initialized, so "
                  "word embeddings in the last layer are not initialized. "
                  "If you are just manipulating a model this is fine, but "
                  "this needs to be handled manually. If you are training "
                  "something is definitely wrong.")
    def sync_initial_word_embeddings(self):

        if torch.distributed.is_initialized():
            if (parallel_state.is_pipeline_first_stage()
                    or parallel_state.is_pipeline_last_stage()):
                torch.distributed.all_reduce(
                    self.word_embeddings_weight().data,
                    group=parallel_state.get_embedding_group())
        else:
            logging.warning(
                "WARNING! Distributed processes aren't initialized, so "
                "word embeddings in the last layer are not synchronized. "
                "If you are just manipulating a model this is fine, but "
                "this needs to be handled manually. If you are training "
                "something is definitely wrong.")
def tab_sample_sequence_batch(
    model,
    context_tokens,
    context_lengths,
    attention_mask,
    position_ids,
    tokens_to_generate,
    all_probs=True,
    type_ids=None,
    temperature=None,
):
    app_state = AppState()
    micro_batch_size = context_tokens.shape[0]
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=micro_batch_size,
        micro_batch_size=micro_batch_size,
        data_parallel_size=1,
    )
    tokenizer = model.tokenizer
    sizes = tokenizer.code_column.sizes
    tokens_per_row = sum(sizes) + 1
    columns = tokenizer.code_column.columns
    num_columns = len(columns)
    tokenid_range = []
    for i in range(num_columns):
        tokenid_range.extend(tokenizer.code_column.get_range(i))

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        context = context_tokens[:, :context_length]
        # the context may start in the middle of the row,
        # calculate the offset according to the position of '\n' or '<|endoftext|>'
        positions = torch.where(context == tokenizer.eor)[1]
        if len(positions) == 0:
            positions = torch.where(context == tokenizer.eod)[1]
        if len(positions) != 0:
            max_position = positions.max().item()
            # TODO: need to make sure contexts in different batches have the same
            # offset lengths; otherwise, the offset needs to be calculated per batch_id
            offset = (context_length - max_position - 1) % tokens_per_row
        else:
            offset = 0
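        # Worked example with hypothetical numbers: if tokens_per_row == 5, the last
        # end-of-row token sits at index max_position == 7, and context_length == 10,
        # then (10 - 7 - 1) % 5 == 2: two tokens of the current row are already in the
        # context, so generation resumes at column offset 2 within that row.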

        eod_id = tokenizer.eos_id

        counter = 0

        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None

        # Generate enough tokens for the longest sequence
        maxlen = tokens_to_generate + context_lengths.max().item()

        if maxlen > model.cfg.encoder_seq_length:
            maxlen = model.cfg.encoder_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            # types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                # not using type2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, :context_length]
            else:
                # Set this to false so the memory is not reallocated.
                set_inference_key_value_memory = False
                tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(
                    batch_size, -1)
                # not using type2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, context_length - 1].view(batch_size, -1)
            # micro_batch_size = 2
            attention_mask_repeat = torch.concat(
                [attention_mask for _ in range(micro_batch_size)])
            setkey_value_array = torch.tensor(
                [set_inference_key_value_memory] * micro_batch_size,
                device=torch.cuda.current_device())
            len_array = torch.tensor([maxlen] * micro_batch_size,
                                     device=torch.cuda.current_device())
            batch = [
                tokens2use, attention_mask_repeat, positions2use,
                setkey_value_array, len_array
            ]
            tensor_shape = [
                tokens2use.shape[1], micro_batch_size, model.cfg.hidden_size
            ]

            output = forward_step(model, batch, tensor_shape)

            if parallel_state.is_pipeline_last_stage():
                output = output[0]['logits'].float()
                output = tensor_parallel.gather_from_tensor_model_parallel_region(
                    output)
                assert output is not None
                output = output.float()
                logits = output[:, -1].view(batch_size, -1).contiguous()
                token_in_row = (counter + offset) % tokens_per_row
                logits = logits.float()
                logits /= temperature
                if token_in_row == tokens_per_row - 1:
                    # line break
                    eor_id = tokenizer.eor
                    eod_id = tokenizer.eos_id
                    min_id = min(eor_id, eod_id)
                    max_id = max(eor_id, eod_id) + 1
                    logits = tab_logits(logits, min_id, max_id)
                else:
                    # limit the range
                    min_id, max_id = tokenid_range[token_in_row]
                    logits = tab_logits(logits, min_id, max_id)
                # multinomial expects probabilities, so use softmax (not log_softmax)
                probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(probs, num_samples=1).view(-1)
                started = context_lengths <= context_length
                # Clamp the out of vocabulary tokens.
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)

                new_tokens = switch(tokens[:, context_length].view(-1), prev,
                                    started)
                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    output_context = F.log_softmax(
                        output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1:context_length + 1],
                                              2)
                    output_logits = torch.gather(output_context, 2,
                                                 indices).squeeze(2)
                    if all_probs:
                        full_logits = output_context
                else:
                    output_context = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output_context, 2,
                                                     indices).squeeze(2)

                    # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
                    output_logits = torch.cat(
                        [output_logits, new_output_logits], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output_context],
                                                1)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eod_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None

            else:
                if parallel_state.is_pipeline_first_stage():
                    src = parallel_state.get_pipeline_model_parallel_last_rank()
                    group = parallel_state.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None

                done = torch.cuda.ByteTensor([0])
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
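
# tab_logits() used above is not shown here. The following is a minimal sketch (an
# assumption about the general technique, not the actual NeMo tab_logits) of
# restricting sampling to a token-id window: every logit outside [min_id, max_id)
# is pushed to -inf so only ids valid for the current column (or the end-of-row /
# end-of-text pair) can be sampled.
import torch


def sketch_tab_logits(logits, min_id, max_id, filter_value=-float('Inf')):
    # logits: [batch_size, vocab_size]; only ids in [min_id, max_id) stay sampleable.
    out = logits.clone()
    out[:, :min_id] = filter_value
    out[:, max_id:] = filter_value
    return out
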
def sample_sequence_batch(
    model,
    context_tokens,
    context_lengths,
    task_ids,
    attention_mask,
    position_ids,
    tokens_to_generate,
    all_probs=False,
    type_ids=None,
    temperature=None,
    extra={},
):
    # Importing here to avoid circular import errors
    from nemo.collections.nlp.models.language_modeling import MegatronGPTPromptLearningModel

    app_state = AppState()
    micro_batch_size = context_tokens.shape[0]
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=micro_batch_size,
        micro_batch_size=micro_batch_size,
        data_parallel_size=1,
    )
    tokenizer = model.tokenizer
    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval, which passes
        # eos_id as an argument and needs termination when that id is found.
        eod_id = tokenizer.eos_id
        counter = 0

        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None
        all_generated_indices = None  # used to track all generated indices
        # Generate enough tokens for the longest sequence
        maxlen = tokens_to_generate + context_lengths.max().item()

        if maxlen > model.cfg.encoder_seq_length + 1:
            maxlen = model.cfg.encoder_seq_length + 1

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            # types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                # not using type2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, :context_length]
            else:
                # Set this to false so the memory is not reallocated.
                set_inference_key_value_memory = False
                tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(
                    batch_size, -1)
                # not using type2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, context_length - 1].view(batch_size, -1)

            attention_mask_repeat = torch.concat(
                [attention_mask for _ in range(micro_batch_size)])
            setkey_value_array = torch.tensor(
                [set_inference_key_value_memory] * micro_batch_size,
                device=torch.cuda.current_device())
            len_array = torch.tensor([maxlen] * micro_batch_size,
                                     device=torch.cuda.current_device())

            # Only prompt learning models will have a prompt table, and require task ids
            if isinstance(model, MegatronGPTPromptLearningModel):
                batch = [
                    tokens2use, attention_mask_repeat, positions2use, task_ids,
                    setkey_value_array, len_array
                ]
                tensor_shape = [
                    tokens2use.shape[1], micro_batch_size,
                    model.frozen_model.cfg.hidden_size
                ]
            else:
                batch = [
                    tokens2use, attention_mask_repeat, positions2use,
                    setkey_value_array, len_array
                ]
                tensor_shape = [
                    tokens2use.shape[1], micro_batch_size,
                    model.cfg.hidden_size
                ]

            output = forward_step(model, batch, tensor_shape)

            if parallel_state.is_pipeline_last_stage():
                output = output[0]['logits'].float()
                output = tensor_parallel.gather_from_tensor_model_parallel_region(
                    output)
                assert output is not None
                output = output.float()
                logits = output[:, -1].view(batch_size, -1).contiguous()

                # make sure it will generate at least min_length
                min_length = extra.get('min_tokens_to_generate', 0)
                if min_length > 0:
                    within_min_length = (context_length -
                                         context_lengths) < min_length
                    logits[within_min_length, eod_id] = -float('Inf')

                # make sure it won't sample outside the vocab_size range
                logits[:, tokenizer.vocab_size:] = -float('Inf')

                if extra.get('greedy', False):
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= temperature
                    # handle repetition penalty
                    logits = repetition_penalty(
                        logits, extra.get('repetition_penalty', 1.2),
                        all_generated_indices)
                    logits = top_k_logits(logits,
                                          top_k=extra.get('top_k', 0),
                                          top_p=extra.get('top_p', 0.9))
                    # multinomial expects probabilities, so use softmax (not log_softmax)
                    probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(probs, num_samples=1).view(-1)
                started = context_lengths <= context_length

                # Clamp the predicted out of vocabulary tokens
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
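                # switch() keeps the next prompt token for rows that have not started
                # generating yet (their context extends past context_length) and takes
                # the newly sampled token otherwise.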
                new_tokens = switch(tokens[:, context_length].view(-1), prev,
                                    started)

                # Replace sampled tokens w/ done token if EOD has already been sampled
                new_tokens = switch(new_tokens, eod_id, is_done)

                # Replace special soft prompt token ids with unk token ids
                if isinstance(model, MegatronGPTPromptLearningModel):
                    pseudo_token_ids_start = model.pseudo_token_ids_start
                    new_tokens[(new_tokens >=
                                pseudo_token_ids_start)] = tokenizer.unk_id
                    tokens[:, :context_length][(
                        tokens[:, :context_length] >=
                        pseudo_token_ids_start)] = tokenizer.unk_id

                # Insert either new predicted or next prompt token
                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    output = F.log_softmax(output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1:context_length + 1],
                                              2)
                    output_logits = torch.gather(output, 2, indices).squeeze(2)
                    all_generated_indices = indices[:, :, 0]
                    if all_probs:
                        full_logits = output
                else:
                    output = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output, 2,
                                                     indices).squeeze(2)

                    # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
                    output_logits = torch.cat(
                        [output_logits, new_output_logits], 1)
                    all_generated_indices = torch.cat(
                        [all_generated_indices, indices[:, :, 0]], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output], 1)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eod_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None

            else:
                if parallel_state.is_pipeline_first_stage():
                    src = parallel_state.get_pipeline_model_parallel_last_rank()
                    group = parallel_state.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None

                done = torch.cuda.ByteTensor([0])
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
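
# top_k_logits() and repetition_penalty() used above are not shown here. Below is a
# minimal, self-contained sketch (an assumption about the general techniques, not
# the actual NeMo implementations) of top-k / nucleus (top-p) filtering and a
# CTRL-style repetition penalty, to make the sampling pipeline in
# sample_sequence_batch concrete.
import torch
import torch.nn.functional as F


def sketch_top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    # logits: [batch_size, vocab_size]
    if top_k > 0:
        # Drop every token whose logit is below the k-th largest logit in its row.
        kth_logit = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_logit, filter_value)
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mask tokens once the cumulative probability exceeds top_p, shifting by one
        # position so the first token that crosses the threshold is still kept.
        sorted_mask = cumulative_probs > top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False
        mask = sorted_mask.scatter(1, sorted_indices, sorted_mask)
        logits = logits.masked_fill(mask, filter_value)
    return logits


def sketch_repetition_penalty(logits, penalty, previous_ids):
    # previous_ids: [batch_size, num_generated] token ids already produced per row.
    if previous_ids is None or penalty == 1.0:
        return logits
    prev_logits = torch.gather(logits, 1, previous_ids)
    # CTRL-style penalty: shrink positive logits, amplify negative ones.
    prev_logits = torch.where(prev_logits < 0, prev_logits * penalty, prev_logits / penalty)
    return logits.scatter(1, previous_ids, prev_logits)
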
def synced_generate(
    model,
    context_tokens_tensor,
    context_length_tensor,
    task_ids,
    tokens_to_generate,
    all_probs,
    temperature,
    top_k=0,
    top_p=0.0,
    greedy=False,
    repetition_penalty=1.2,
    min_tokens_to_generate=0,
):
    context_length = context_length_tensor.min().item()
    tokenizer = model.tokenizer
    tokens, attention_mask, position_ids = get_batch(model, tokenizer,
                                                     context_tokens_tensor)
    if isinstance(tokenizer, TabularTokenizer):
        batch_token_iterator = tab_sample_sequence_batch(
            model,
            context_tokens_tensor,
            context_length_tensor,
            attention_mask,
            position_ids,
            tokens_to_generate,
            all_probs,
            temperature=temperature,
        )
    else:
        batch_token_iterator = sample_sequence_batch(
            model,
            context_tokens_tensor,
            context_length_tensor,
            task_ids,
            attention_mask,
            position_ids,
            tokens_to_generate,
            all_probs,
            temperature=temperature,
            extra={
                "top_p": top_p,
                "top_k": top_k,
                "greedy": greedy,
                "repetition_penalty": repetition_penalty,
                "min_tokens_to_generate": min_tokens_to_generate,
            },
        )

    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
        context_length += 1

    if parallel_state.is_pipeline_last_stage():
        src = parallel_state.get_pipeline_model_parallel_last_rank()
        group = parallel_state.get_embedding_group()
        torch.distributed.broadcast(output_logits, src, group)
        if all_probs:
            src = parallel_state.get_pipeline_model_parallel_last_rank()
            group = parallel_state.get_embedding_group()
            torch.distributed.broadcast(full_logits, src, group)

    else:
        if parallel_state.is_pipeline_first_stage():
            src = parallel_state.get_pipeline_model_parallel_last_rank()
            group = parallel_state.get_embedding_group()
            output_logits = torch.empty(tokens.size(0),
                                        context_length - 1,
                                        dtype=torch.float32,
                                        device=torch.device("cuda"))
            torch.distributed.broadcast(output_logits, src, group)

            if all_probs:
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                full_logits = torch.empty(
                    tokens.size(0),
                    context_length - 1,
                    model.padded_vocab_size,
                    dtype=torch.float32,
                    device=torch.device("cuda"),
                )
                torch.distributed.broadcast(full_logits, src, group)
    if tokens is not None:
        return tokens[:, :context_length], output_logits, full_logits
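
# get_batch() used above is not shown here. The following is a minimal, self-contained
# sketch (an assumption, not NeMo's implementation) of how a causal attention mask and
# position ids for a [batch, seq] tensor of token ids can be built; the real helper may
# use the opposite mask convention and also derive a loss mask.
import torch


def sketch_build_attention_mask_and_position_ids(tokens):
    seq_length = tokens.size(1)
    # Lower-triangular causal mask shared across the batch; here True means the
    # position may be attended to.
    attention_mask = torch.tril(
        torch.ones((1, seq_length, seq_length), device=tokens.device)
    ).bool()
    # Position ids are simply 0 .. seq_length-1 for every sample.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=tokens.device)
    position_ids = position_ids.unsqueeze(0).expand_as(tokens)
    return attention_mask, position_ids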