Example #1
    def decode(self, tokens_enc, enc_mask, num_tokens_to_generate):
        encoder_hidden_states = self(
            encoder_input_ids=tokens_enc,
            decoder_input_ids=None,
            encoder_attn_mask=enc_mask,
            decoder_attn_mask=None,
            encoder_decoder_attn_mask=None,
            tokentype_ids=None,
            lm_labels=None,
            enc_hidden_states=None,
            output_enc_hidden_only=True,
        )
        predicted_tokens_dec = torch.LongTensor(
            [self.tokenizer.bos_id]).unsqueeze(0).to(tokens_enc.device)

        for _ in range(num_tokens_to_generate):
            # Overwrite the decoder token since we want to predict the next token
            enc_dec_mask = self.make_inference_attention_mask_3d(
                predicted_tokens_dec, tokens_enc, self.tokenizer.pad_id)
            dec_mask = self.make_inference_attention_mask_3d(
                predicted_tokens_dec, predicted_tokens_dec,
                self.tokenizer.pad_id)
            dec_mask = dec_mask * self.make_inference_history_mask_3d(
                predicted_tokens_dec)

            enc_dec_mask = enc_dec_mask < 0.5
            dec_mask = dec_mask < 0.5

            output_tensor, _ = self(
                encoder_input_ids=tokens_enc,
                decoder_input_ids=predicted_tokens_dec,
                encoder_attn_mask=enc_mask,
                decoder_attn_mask=dec_mask,
                encoder_decoder_attn_mask=enc_dec_mask,
                tokentype_ids=None,
                lm_labels=None,
                enc_hidden_states=encoder_hidden_states,
                output_enc_hidden_only=False,
            )
            output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(
                output_tensor)
            log_probs, token_ids = torch.max(
                nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
            predicted_tokens_dec = torch.cat(
                [predicted_tokens_dec, token_ids[:, -1].unsqueeze(1)], 1)
            if token_ids[:, -1] == self.tokenizer.eos_id:
                break

        return predicted_tokens_dec, log_probs
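Most of the examples below repeat the greedy decoding pattern shown above: the vocabulary logits are gathered across tensor-model-parallel ranks with gather_from_tensor_model_parallel_region, converted to log-probabilities, and the argmax at the last position is appended to the decoder input until EOS or the token budget is reached. As a point of reference, here is a minimal single-process sketch of that loop with no tensor parallelism; toy_logits_fn, bos_id and eos_id are hypothetical stand-ins, not part of the code above.

import torch
import torch.nn.functional as F

def greedy_decode(toy_logits_fn, bos_id, eos_id, max_steps=8):
    # Start from the beginning-of-sequence token (batch size 1).
    predicted = torch.tensor([[bos_id]], dtype=torch.long)
    log_probs = None
    for _ in range(max_steps):
        logits = toy_logits_fn(predicted)  # [batch, seq, vocab]
        # Greedy step: pick the most likely token at every position.
        log_probs, token_ids = torch.max(F.log_softmax(logits, dim=-1), dim=-1)
        predicted = torch.cat([predicted, token_ids[:, -1:]], dim=1)
        if token_ids[0, -1].item() == eos_id:
            break
    return predicted, log_probs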
Example #2
def parallel_lm_logits(input_,
                       word_embeddings_weight,
                       parallel_output,
                       bias=None):
    """LM logits using word embedding weights."""
    # Parallel logits.
    input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(
        input_)
    # Matrix multiply.
    if bias is None:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight)
    else:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight,
                                   bias)
    # Gather if needed.
    if parallel_output:
        return logits_parallel

    return tensor_parallel.gather_from_tensor_model_parallel_region(
        logits_parallel)
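parallel_lm_logits projects hidden states onto the (tensor-parallel) word-embedding matrix to produce vocabulary logits, and gathers them across ranks only when a non-parallel output is requested. Stripped of the parallel regions, the weight-tying step is just a linear projection onto the transposed embedding table; a small sketch with made-up shapes:

import torch
import torch.nn.functional as F

hidden = torch.randn(2, 5, 16)                 # [batch, seq, hidden]
word_embeddings_weight = torch.randn(100, 16)  # [vocab, hidden]

# F.linear multiplies by the transposed weight, giving [batch, seq, vocab].
logits = F.linear(hidden, word_embeddings_weight)
print(logits.shape)  # torch.Size([2, 5, 100])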
Example #3
    def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, enc_input=None):
        # TODO: move method into a class inside MegatronTokenLevelEncoderDecoderModule (?)
        encoder_hidden_states, enc_output_mask = itemgetter("enc_output", "enc_output_mask")(
            self(
                encoder_input_ids=tokens_enc,
                decoder_input_ids=None,
                encoder_attn_mask=enc_mask,
                decoder_attn_mask=None,
                tokentype_ids=None,
                lm_labels=None,
                enc_hidden_states=None,
                enc_output_mask=None,
                output_enc_hidden_only=True,
                enc_input=enc_input,
            )
        )
        predicted_tokens_dec = (
            torch.LongTensor([self.tokenizer.bos_id] * tokens_enc.size(0)).unsqueeze(1).to(tokens_enc.device)
        )
        for _ in range(num_tokens_to_generate):
            dec_mask = predicted_tokens_dec != self.tokenizer.pad_id
            token_logits = itemgetter("token_logits")(
                self(
                    encoder_input_ids=tokens_enc,
                    decoder_input_ids=predicted_tokens_dec,
                    encoder_attn_mask=enc_mask,
                    decoder_attn_mask=dec_mask,
                    tokentype_ids=None,
                    lm_labels=None,
                    enc_hidden_states=encoder_hidden_states,
                    enc_output_mask=enc_output_mask,
                    output_enc_hidden_only=False,
                    enc_input=enc_input,
                )
            )
            token_logits = tensor_parallel.gather_from_tensor_model_parallel_region(token_logits)
            log_probs, token_ids = torch.max(nn.functional.log_softmax(token_logits, dim=-1), dim=-1)
            predicted_tokens_dec = torch.cat([predicted_tokens_dec, token_ids[:, -1].unsqueeze(1)], 1)

        return predicted_tokens_dec, log_probs
Example #4
    def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None):
        app_state = AppState()
        global_batch_per_gpu = tokens_enc.size(0)

        num_micro_batches_before_decode = get_num_microbatches()
        # Reconfigure microbatch calculator here to set num microbatches to 1 while decoding since it's not clear how to decode with "grad acc".
        # TODO: reconfigure back to how things were before decode?
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
        predicted_tokens_dec = (
            torch.LongTensor([self.tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
        )
        encoder_seq_length = tokens_enc.size(1)
        tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]
        assert predicted_tokens_dec.size(0) == global_batch_per_gpu

        for i in range(num_tokens_to_generate):
            # No microbatches in decoding. Just the global batch.
            decoder_seq_length = predicted_tokens_dec.size(1)
            dec_mask = predicted_tokens_dec != self.tokenizer.pad_id

            if encoder_input is not None:
                batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
            else:
                batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]
            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                output_tensor = forward_backward_pipelining_without_interleaving(
                    forward_step_func=self.get_forward_output_only_func(),
                    batch=batch_for_pipeline,
                    model=self.enc_dec_model,
                    forward_only=True,
                    tensor_shape=tensor_shape,
                    decoder_sequence_length=decoder_seq_length,
                    dtype=self.autocast_dtype,
                )
            else:
                output_tensor = forward_backward_no_pipelining(
                    forward_step_func=self.get_forward_output_only_func(),
                    batch=batch_for_pipeline,
                    model=self.enc_dec_model,
                    forward_only=True,
                    tensor_shape=tensor_shape,
                    decoder_sequence_length=decoder_seq_length,
                    dtype=self.autocast_dtype,
                )

            # get output tensor
            if parallel_state.is_pipeline_last_stage():
                output_tensor = output_tensor[0]['logits']
                output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
                log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
                predicted_tokens_dec = torch.cat(
                    [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
                )
            else:
                log_probs = torch.zeros(
                    (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
                ).cuda()
                predicted_tokens_dec = torch.zeros(
                    (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                    dtype=predicted_tokens_dec.dtype,
                ).cuda()

            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                # Broadcast from the last pipeline stage to all other model-parallel ranks.
                torch.distributed.broadcast(
                    predicted_tokens_dec,
                    parallel_state.get_pipeline_model_parallel_last_rank(),
                    group=parallel_state.get_model_parallel_group(),
                )
                torch.distributed.broadcast(
                    log_probs,
                    parallel_state.get_pipeline_model_parallel_last_rank(),
                    group=parallel_state.get_model_parallel_group(),
                )

        # Reset microbatch calculator to what it was before decoding.
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
        return predicted_tokens_dec, log_probs
Example #5
    def decode(self, enc_query, enc_taskname, label_position,
               num_tokens_to_generate):
        with torch.no_grad():
            predicted_tokens_dec = enc_query

            label_start = label_position[:, 0].clone()

            for _ in range(num_tokens_to_generate):
                attn_mask = make_attention_mask_3d(predicted_tokens_dec,
                                                   predicted_tokens_dec,
                                                   self.pad_token_id)
                attn_mask = attn_mask * make_history_mask_3d(
                    predicted_tokens_dec)

                attn_mask = attn_mask < 0.5

                attn_mask = attn_mask.unsqueeze(1)

                input_embeds = self.embed_input(predicted_tokens_dec,
                                                enc_taskname)

                encoder_position_ids = build_position_ids(predicted_tokens_dec)
                position_embeddings = self.model.model.language_model.embedding.position_embeddings(
                    encoder_position_ids)

                encoder_input = input_embeds + position_embeddings

                if self.float_type == torch.float32:
                    output = self.model.model(
                        None,
                        None,
                        encoder_input=encoder_input,
                        attention_mask=attn_mask,
                    )
                else:
                    with torch.autocast(device_type="cuda",
                                        dtype=self.float_type):
                        output = self.model.model(
                            None,
                            None,
                            encoder_input=encoder_input,
                            attention_mask=attn_mask,
                        )
                output_tensor = output

                output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(
                    output_tensor)

                # TODO: add logic to use the allowed labels if they are defined
                log_probs, token_ids = torch.max(
                    nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)

                new_pred = torch.full_like(token_ids[:, 0:1],
                                           self.pad_token_id)
                predicted_tokens_dec = torch.cat(
                    [predicted_tokens_dec, new_pred], 1)

                predicted = torch.gather(token_ids, 1, label_start.view(-1, 1))

                # need to scatter the token id at the right position
                label_start += 1
                predicted_tokens_dec.scatter_(1, label_start.view(-1, 1),
                                              predicted)

        return predicted_tokens_dec, log_probs
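The example above builds its self-attention mask by multiplying a padding mask with a lower-triangular history mask and then inverting the result with `< 0.5`, so that True marks positions that must not be attended to. A pure-PyTorch sketch of the same idea follows; it approximates what make_attention_mask_3d and make_history_mask_3d are assumed to do and is not their actual implementation.

import torch

def causal_pad_mask(tokens, pad_id):
    # [batch, seq] -> 1.0 where the token is not padding.
    not_pad = (tokens != pad_id).float()
    # Outer product per batch: query and key must both be non-padding.
    attn = not_pad.unsqueeze(2) * not_pad.unsqueeze(1)  # [batch, seq, seq]
    # Lower-triangular history mask enforces causality.
    history = torch.tril(torch.ones_like(attn))
    # Invert: True = masked out, matching the `< 0.5` convention above.
    return (attn * history) < 0.5

mask = causal_pad_mask(torch.tensor([[11, 12, 0]]), pad_id=0)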
Example #6
    def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None, tokenizer=None):
        # Check whether DDP is initialized. This is needed when running inference outside of the training loop.
        if parallel_state.is_unitialized():

            def dummy():
                return

            if self.trainer.strategy.launcher is not None:
                self.trainer.strategy.launcher.launch(dummy, trainer=self.trainer)
            self.trainer.strategy.setup_environment()

            # Reconfigure microbatch sizes here because on model restore, this will contain the micro/global batch configuration used while training.
            _reconfigure_microbatch_calculator(
                rank=0,  # This doesn't matter since it is only used for logging
                rampup_batch_size=None,
                global_batch_size=1,
                micro_batch_size=1,  # Make sure that there is no "grad acc" while decoding.
                data_parallel_size=1,  # We check above to make sure that data parallel size is always 1 at inference.
            )

        # Classes that inherit from this class may use a different tokenizer; fall back to self.tokenizer if none is given.
        tokenizer = self.tokenizer if tokenizer is None else tokenizer
        app_state = AppState()
        global_batch_per_gpu = tokens_enc.size(0)

        num_micro_batches_before_decode = get_num_microbatches()
        # Reconfigure microbatch calculator here to set num microbatches to 1 while decoding since it's not clear how to decode with "grad acc".
        # TODO: reconfigure back to how things were before decode?
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
        predicted_tokens_dec = (
            torch.LongTensor([tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
        )
        encoder_seq_length = tokens_enc.size(1)
        tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]
        assert predicted_tokens_dec.size(0) == global_batch_per_gpu

        for i in range(num_tokens_to_generate):
            # No microbatches in decoding. Just the global batch.
            decoder_seq_length = predicted_tokens_dec.size(1)
            dec_mask = predicted_tokens_dec != tokenizer.pad_id

            if encoder_input is not None:
                batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
            else:
                batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]
            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                output_tensor = forward_backward_pipelining_without_interleaving(
                    forward_step_func=self.get_forward_output_only_func(),
                    batch=batch_for_pipeline,
                    model=self.enc_dec_model,
                    forward_only=True,
                    tensor_shape=tensor_shape,
                    decoder_sequence_length=decoder_seq_length,
                    dtype=self.autocast_dtype,
                )
            else:
                output_tensor = forward_backward_no_pipelining(
                    forward_step_func=self.get_forward_output_only_func(),
                    batch=batch_for_pipeline,
                    model=self.enc_dec_model,
                    forward_only=True,
                    tensor_shape=tensor_shape,
                    decoder_sequence_length=decoder_seq_length,
                    dtype=self.autocast_dtype,
                )

            # get output tensor
            if parallel_state.is_pipeline_last_stage():
                output_tensor = output_tensor[0]['logits']
                output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
                log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
                predicted_tokens_dec = torch.cat(
                    [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
                )
            else:
                log_probs = torch.zeros(
                    (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
                ).cuda()
                predicted_tokens_dec = torch.zeros(
                    (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                    dtype=predicted_tokens_dec.dtype,
                ).cuda()

            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
                # Broadcast from the last pipeline stage to all other model-parallel ranks.
                torch.distributed.broadcast(
                    predicted_tokens_dec,
                    parallel_state.get_pipeline_model_parallel_last_rank(),
                    group=parallel_state.get_model_parallel_group(),
                )
                torch.distributed.broadcast(
                    log_probs,
                    parallel_state.get_pipeline_model_parallel_last_rank(),
                    group=parallel_state.get_model_parallel_group(),
                )

        # Reset microbatch calculator to what it was before decoding.
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
            micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )
        return predicted_tokens_dec, log_probs
Example #7
def tab_sample_sequence_batch(
    model,
    context_tokens,
    context_lengths,
    attention_mask,
    position_ids,
    tokens_to_generate,
    all_probs=True,
    type_ids=None,
    temperature=None,
):
    app_state = AppState()
    micro_batch_size = context_tokens.shape[0]
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=micro_batch_size,
        micro_batch_size=micro_batch_size,
        data_parallel_size=1,
    )
    tokenizer = model.tokenizer
    sizes = tokenizer.code_column.sizes
    tokens_per_row = sum(sizes) + 1
    columns = tokenizer.code_column.columns
    num_columns = len(columns)
    tokenid_range = []
    for i in range(num_columns):
        tokenid_range.extend(tokenizer.code_column.get_range(i))

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        context = context_tokens[:, :context_length]
        # the context may start in the middle of the row,
        # calculate the offset according to the position of '\n' or '<|endoftext|>'
        positions = torch.where(context == tokenizer.eor)[1]
        if len(positions) == 0:
            positions = torch.where(context == tokenizer.eod)[1]
        if len(positions) != 0:
            max_position = positions.max().item()
            # TODO: need to make sure contexts in different batches have the same offset lengths;
            # otherwise, the offset needs to be calculated per batch_id
            offset = (context_length - max_position - 1) % tokens_per_row
        else:
            offset = 0

        eod_id = tokenizer.eos_id

        counter = 0

        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None

        # Generate enough tokens for the longest sequence
        maxlen = tokens_to_generate + context_lengths.max().item()

        if maxlen > model.cfg.encoder_seq_length:
            maxlen = model.cfg.encoder_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            # types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                # not using types2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, :context_length]
            else:
                # Set this to false so the memory is not reallocated.
                set_inference_key_value_memory = False
                tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(
                    batch_size, -1)
                # not using types2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, context_length - 1].view(batch_size, -1)
            # micro_batch_size = 2
            attention_mask_repeat = torch.concat(
                [attention_mask for _ in range(micro_batch_size)])
            setkey_value_array = torch.tensor(
                [set_inference_key_value_memory] * micro_batch_size,
                device=torch.cuda.current_device())
            len_array = torch.tensor([maxlen] * micro_batch_size,
                                     device=torch.cuda.current_device())
            batch = [
                tokens2use, attention_mask_repeat, positions2use,
                setkey_value_array, len_array
            ]
            tensor_shape = [
                tokens2use.shape[1], micro_batch_size, model.cfg.hidden_size
            ]

            output = forward_step(model, batch, tensor_shape)

            if parallel_state.is_pipeline_last_stage():
                output = output[0]['logits'].float()
                output = tensor_parallel.gather_from_tensor_model_parallel_region(
                    output)
                assert output is not None
                output = output.float()
                logits = output[:, -1].view(batch_size, -1).contiguous()
                token_in_row = (counter + offset) % tokens_per_row
                logits = logits.float()
                logits /= temperature
                if token_in_row == tokens_per_row - 1:
                    # line break
                    eor_id = tokenizer.eor
                    eod_id = tokenizer.eos_id
                    min_id = min(eor_id, eod_id)
                    max_id = max(eor_id, eod_id) + 1
                    logits = tab_logits(logits, min_id, max_id)
                else:
                    # limit the range
                    min_id, max_id = tokenid_range[token_in_row]
                    logits = tab_logits(logits, min_id, max_id)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                started = context_lengths <= context_length
                # Clamp the out of vocabulary tokens.
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)

                new_tokens = switch(tokens[:, context_length].view(-1), prev,
                                    started)
                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    output_context = F.log_softmax(
                        output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1:context_length + 1],
                                              2)
                    output_logits = torch.gather(output_context, 2,
                                                 indices).squeeze(2)
                    if all_probs:
                        full_logits = output_context
                else:
                    output_context = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output_context, 2,
                                                     indices).squeeze(2)

                    # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
                    output_logits = torch.cat(
                        [output_logits, new_output_logits], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output_context],
                                                1)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eod_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None

            else:
                if parallel_state.is_pipeline_first_stage():
                    src = parallel_state.get_pipeline_model_parallel_last_rank()
                    group = parallel_state.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None

                done = torch.cuda.ByteTensor([0])
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
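The example above generates tabular rows column by column: before sampling, the logits are restricted either to the token-id range valid for the current column or, at the end of a row, to the end-of-row/end-of-document ids. tab_logits itself is not shown here; under the assumption that it suppresses everything outside the allowed range, a plausible sketch is:

import torch

def restrict_to_range(logits, min_id, max_id):
    # Force every vocabulary entry outside [min_id, max_id) to -inf so that
    # softmax/multinomial can only pick ids valid for the current column.
    masked = torch.full_like(logits, float('-inf'))
    masked[..., min_id:max_id] = logits[..., min_id:max_id]
    return masked

restricted = restrict_to_range(torch.randn(2, 100), min_id=10, max_id=20)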
Example #8
def sample_sequence_batch(
    model,
    context_tokens,
    context_lengths,
    task_ids,
    attention_mask,
    position_ids,
    tokens_to_generate,
    all_probs=False,
    type_ids=None,
    temperature=None,
    extra={},
):
    # Importing here to avoid circular import errors
    from nemo.collections.nlp.models.language_modeling import MegatronGPTPromptLearningModel

    app_state = AppState()
    micro_batch_size = context_tokens.shape[0]
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=micro_batch_size,
        micro_batch_size=micro_batch_size,
        data_parallel_size=1,
    )
    tokenizer = model.tokenizer
    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval that passes
        # eos_id as an argument and needs termination when that id is found.
        eod_id = tokenizer.eos_id
        counter = 0

        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None
        all_generated_indices = None  # used to track all generated indices
        # Generate enough tokens for the longest sequence
        maxlen = tokens_to_generate + context_lengths.max().item()

        if maxlen > model.cfg.encoder_seq_length + 1:
            maxlen = model.cfg.encoder_seq_length + 1

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            # types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                # not using types2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, :context_length]
            else:
                # Set this to false so the memory is not reallocated.
                set_inference_key_value_memory = False
                tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(
                    batch_size, -1)
                # not using types2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, context_length - 1].view(batch_size, -1)

            attention_mask_repeat = torch.concat(
                [attention_mask for _ in range(micro_batch_size)])
            setkey_value_array = torch.tensor(
                [set_inference_key_value_memory] * micro_batch_size,
                device=torch.cuda.current_device())
            len_array = torch.tensor([maxlen] * micro_batch_size,
                                     device=torch.cuda.current_device())

            # Only prompt learning models will have a prompt table, and require task ids
            if isinstance(model, MegatronGPTPromptLearningModel):
                batch = [
                    tokens2use, attention_mask_repeat, positions2use, task_ids,
                    setkey_value_array, len_array
                ]
                tensor_shape = [
                    tokens2use.shape[1], micro_batch_size,
                    model.frozen_model.cfg.hidden_size
                ]
            else:
                batch = [
                    tokens2use, attention_mask_repeat, positions2use,
                    setkey_value_array, len_array
                ]
                tensor_shape = [
                    tokens2use.shape[1], micro_batch_size,
                    model.cfg.hidden_size
                ]

            output = forward_step(model, batch, tensor_shape)

            if parallel_state.is_pipeline_last_stage():
                output = output[0]['logits'].float()
                output = tensor_parallel.gather_from_tensor_model_parallel_region(
                    output)
                assert output is not None
                output = output.float()
                logits = output[:, -1].view(batch_size, -1).contiguous()

                # make sure it will generate at least min_length
                min_length = extra.get('min_tokens_to_generate', 0)
                if min_length > 0:
                    within_min_length = (context_length -
                                         context_lengths) < min_length
                    logits[within_min_length, eod_id] = -float('Inf')

                # make sure it won't sample outside the vocab_size range
                logits[:, tokenizer.vocab_size:] = -float('Inf')

                if extra.get('greedy', False):
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= temperature
                    # handle repetition penalty
                    logits = repetition_penalty(
                        logits, extra.get('repetition_penalty', 1.2),
                        all_generated_indices)
                    logits = top_k_logits(logits,
                                          top_k=extra.get('top_k', 0),
                                          top_p=extra.get('top_p', 0.9))
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                started = context_lengths <= context_length

                # Clamp the predicted out of vocabulary tokens
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
                new_tokens = switch(tokens[:, context_length].view(-1), prev,
                                    started)

                # Replace sampled tokens w/ done token if EOD has already been sampled
                new_tokens = switch(new_tokens, eod_id, is_done)

                # Replace special soft prompt token ids with unk token ids
                if isinstance(model, MegatronGPTPromptLearningModel):
                    pseudo_token_ids_start = model.pseudo_token_ids_start
                    new_tokens[(new_tokens >=
                                pseudo_token_ids_start)] = tokenizer.unk_id
                    tokens[:, :context_length][(
                        tokens[:, :context_length] >=
                        pseudo_token_ids_start)] = tokenizer.unk_id

                # Insert either new predicted or next prompt token
                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    output = F.log_softmax(output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1:context_length + 1],
                                              2)
                    output_logits = torch.gather(output, 2, indices).squeeze(2)
                    all_generated_indices = indices[:, :, 0]
                    if all_probs:
                        full_logits = output
                else:
                    output = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output, 2,
                                                     indices).squeeze(2)

                    # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
                    output_logits = torch.cat(
                        [output_logits, new_output_logits], 1)
                    all_generated_indices = torch.cat(
                        [all_generated_indices, indices[:, :, 0]], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output], 1)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eod_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None

            else:
                if parallel_state.is_pipeline_first_stage():
                    src = parallel_state.get_pipeline_model_parallel_last_rank()
                    group = parallel_state.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None

                done = torch.cuda.ByteTensor([0])
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
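Before drawing the next token, the example above applies several sampling controls: a minimum-length constraint, temperature scaling, a repetition penalty, and top-k/top-p filtering, followed by a multinomial draw over the softmax distribution (despite its name, log_probs there holds plain probabilities, which is what torch.multinomial expects). A stripped-down sketch of the temperature plus top-k part, without the Megatron-specific helpers such as repetition_penalty and top_k_logits:

import torch
import torch.nn.functional as F

def sample_next(logits, temperature=1.0, top_k=0):
    logits = logits / temperature
    if top_k > 0:
        # Keep only the top_k highest logits per row; mask the rest to -inf.
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float('-inf'))
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).view(-1)

next_ids = sample_next(torch.randn(2, 100), temperature=0.7, top_k=40)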
Example #9
    def complete(self, request: List, positions: List,
                 tokens_to_generate: int):
        """
            Autoregressively invokes language model in the inference mode
        Args:
            request: 
                * tokens: List of "buckets" with unpadded tokens of the same length
                * prompt_tags: List of "buckets" where each bucket contains the prompt_tag strings
                               specifying the prompt tag to use (optional)
            positions: List with initial prompts positions
            tokens_to_generate: int value denoting amount of tokens model should generate

        Returns:	
            response: A python list of tuples
                (text, tokens, log_probs, offsets)
                * text: string, inputted prompt + generated text by model
                * tokens: list of tokens correspond to text
                * log_probs: list of tokens log probabilities
                * offsets: list of tokens start positions in text
                
        """
        results = []
        request_tokens = request["tokens"]

        for idx, tokens in enumerate(request_tokens):

            # For prompt tuned GPT models
            if self.use_soft_prompts:
                prompt_tags = request["prompt_tags"][idx]
            else:
                prompt_tags = None

            logsoftmaxlayer = torch.nn.LogSoftmax(dim=-1)

            for i in range(tokens_to_generate + 1):
                if self.use_soft_prompts:
                    batch_size = len(tokens)
                    full_length = len(tokens[0]) + self.num_prompt_tokens

                    # Get position ids for text after the soft prompt
                    position_ids = torch.arange(start=self.num_prompt_tokens,
                                                end=full_length,
                                                dtype=torch.long,
                                                device=self.device)
                    position_ids = position_ids.unsqueeze(0).expand_as(
                        tokens).clone()

                    # Make attention mask starting with first token in soft prompt
                    attention_mask = torch.tril(
                        torch.ones(
                            (batch_size, full_length, full_length),
                            device=self.device)).view(batch_size, 1,
                                                      full_length, full_length)
                    attention_mask = attention_mask < 0.5

                else:
                    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
                        data=tokens,
                        eod_token=self.tokenizer.eos_id,
                        reset_position_ids=self.cfg.get(
                            'reset_position_ids', False),
                        reset_attention_mask=self.cfg.get(
                            'reset_attention_mask', False),
                        eod_mask_loss=self.cfg.get('eod_mask_loss', False),
                    )

                # No labels during inference. We still need masks so positions cannot attend to the right
                output_tensor = self(tokens,
                                     position_ids,
                                     attention_mask,
                                     prompt_tags=prompt_tags,
                                     labels=None)
                output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(
                    output_tensor)
                log_probs, token_ids = torch.max(
                    logsoftmaxlayer(output_tensor), dim=-1)
                reached_eos = token_ids[0, -1].item() == self.tokenizer.eos_id
                tokens = torch.cat(
                    [tokens, torch.unsqueeze(token_ids[:, -1], 1)], dim=1)

            # add to results as (text, tokens, log_probs, offsets)
            for token, prob in zip(tokens, log_probs.tolist()):
                results.append(
                    (self.tokenizer.ids_to_text(token[:-1]),
                     self.tokenizer.ids_to_tokens(token[:-1]), prob, [0]))
        # offsets calculation
        for item in results:
            for index, token in enumerate(item[1]):
                if index != len(item[1]) - 1:
                    item[3].append(len(token) + item[3][-1])
        # return prompts in the order they were input
        response = [0 for i in range(len(positions))]
        for item, index in zip(results, positions):
            response[index] = item

        return response
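The offsets computed at the end of the example above are the character positions at which each token starts in the detokenized text: the list begins at 0 and each subsequent entry is the running sum of the preceding token lengths. The same bookkeeping in isolation:

tokens = ["Hel", "lo", ",", " world"]
offsets = [0]
for tok in tokens[:-1]:
    # Each token starts where the previous one ended.
    offsets.append(offsets[-1] + len(tok))
print(offsets)  # [0, 3, 5, 6]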
Example #10
    def complete(self, request: Dict):
        """
            Autoregressively invokes the language model in inference mode
        Args:
            request: Dictionary with the following fields
                * prompt: a string of text that the model should complete.
                * tokens_to_generate: how many tokens to generate while doing prompt completion.
                * stop_after_sentence: (default True) whether to stop generation once a sentence end is reached.
        Returns:
            response: A python dictionary with the following fields
                * prompt: original text of the prompt
                * tokenized_prompt: list of (str) tokens from prompt
                * completion: a python dictionary with the following subfields:
                    * tokens: a list of triples (token, token_id, log_prob) comprising the completion
                    * stop reason: either 'eos', 'sentence_end' or 'limit', indicating why generation stopped
                    * text: completion text (as a single string)
        """
        response = {}
        self.freeze()
        logsoftmaxlayer = torch.nn.LogSoftmax(dim=-1)
        response['tokenized_prompt'] = request['tokenized_prompt']
        tokens = request['tokens']
        # naive greedy slow loop
        # TODO: add option for BeamSearchDecoder
        response['prompt'] = request['prompt']
        response['completion'] = {}
        response['completion']['stop reason'] = 'limit'
        for i in range(request.get("tokens_to_generate", 64)):
            attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
                data=tokens,
                eod_token=self.tokenizer.eos_id,
                reset_position_ids=self.cfg.get('reset_position_ids', False),
                reset_attention_mask=self.cfg.get('reset_attention_mask', False),
                eod_mask_loss=self.cfg.get('eod_mask_loss', False),
            )
            # No labels during inference. We still need masks so positions cannot attend to the right
            output_tensor = self(tokens,
                                 position_ids,
                                 attention_mask,
                                 labels=None)
            output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(
                output_tensor)
            log_probs, token_ids = torch.max(logsoftmaxlayer(output_tensor), dim=-1)
            reached_eos = token_ids[0, -1].item() == self.tokenizer.eos_id
            tokens = torch.cat([torch.squeeze(tokens), token_ids[:, -1]])
            response['completion']["tokens"] = list(
                zip(self.tokenizer.ids_to_tokens(tokens), tokens.tolist(),
                    log_probs.tolist()[0]))
            completion_text = self.tokenizer.ids_to_text(
                x[1] for x in response['completion']["tokens"])
            if reached_eos:  # Will it actually ever reach that?
                response['completion']['stop reason'] = 'eos'
                break
            elif request.get("stop_after_sentence", True) and completion_text.endswith(('.', '!', '?')):
                response['completion']['stop reason'] = 'sentence_end'
                break
            tokens = torch.unsqueeze(tokens, 0)
        response['completion']["text"] = self.tokenizer.ids_to_text(
            x[1] for x in response['completion']["tokens"])
        self.unfreeze()
        return response
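Generation in the example above stops for one of three reasons: the token budget runs out ('limit', the default), the model emits the EOS token ('eos'), or the completion ends with sentence punctuation while stop_after_sentence is set ('sentence_end'). The decision logic, distilled into a small standalone helper with illustrative names:

def stop_reason(reached_eos, completion_text, stop_after_sentence=True):
    # 'eos' takes precedence over 'sentence_end'; returning None means
    # "keep generating", and the caller reports 'limit' when the budget ends.
    if reached_eos:
        return 'eos'
    if stop_after_sentence and completion_text.endswith(('.', '!', '?')):
        return 'sentence_end'
    return None

print(stop_reason(False, "It works."))  # sentence_end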