Example #1
0
    def init_prompt_from_text(self, prompt_tag, prompt_id, init_token_ids,
                              embeddings):
        """Add new soft prompt to be tuned.
           Intialize prompt weights from existing embeddings from specific vocab tokens.

        """
        # Trim or repeat tokens until num_text_tokens matches num_prompt_tokens
        num_text_tokens = len(init_token_ids)
        num_prompt_tokens = self.num_prompt_tokens

        if num_text_tokens > num_prompt_tokens:
            init_token_ids = init_token_ids[:num_prompt_tokens]
        elif num_text_tokens < num_prompt_tokens:
            num_reps = math.ceil(num_prompt_tokens / num_text_tokens)
            init_token_ids = init_token_ids * num_reps

        # Set dictionary item keys and datatypes for broadcasting
        keys = ['text']
        datatype = torch.int64

        # Broadcast int ids across gpus for tensor parallel
        init_token_ids = init_token_ids[:num_prompt_tokens]
        init_token_ids = {
            'text': torch.tensor(init_token_ids, dtype=torch.int64)
        }
        init_token_ids_b = tensor_parallel.broadcast_data(
            keys, init_token_ids, datatype)
        init_token_ids = init_token_ids_b['text'].long()
        init_position_ids = torch.arange(self.num_prompt_tokens,
                                         dtype=torch.long,
                                         device=init_token_ids.device)

        # Use a copy of token embedding weights to initialize the prompt embeddings
        word_embeddings, position_embeddings = embeddings(
            init_token_ids, init_position_ids, separate_embeddings=True)

        word_embeddings = word_embeddings.detach().clone()
        position_embeddings = position_embeddings.detach().clone()

        prompt_embeddings = PromptEmbedding(
            init_from_prompt_text=True,
            hidden_size=self.hidden_size,
            num_prompt_tokens=self.num_prompt_tokens,
            word_embedding_weights=word_embeddings,
            position_embedding_weights=position_embeddings,
        )

        self.prompt_table[prompt_tag] = prompt_embeddings
        self.prompt_id_to_tag[prompt_id] = prompt_tag
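The trim-or-repeat step above is easy to verify in isolation. Below is a minimal standalone sketch of just that logic; fit_to_length is an illustrative name, not part of the original class.

import math

def fit_to_length(init_token_ids, num_prompt_tokens):
    """Repeat or trim a list of token ids so exactly num_prompt_tokens remain."""
    num_text_tokens = len(init_token_ids)
    if num_text_tokens < num_prompt_tokens:
        # Tile the ids enough times to cover the prompt length ...
        num_reps = math.ceil(num_prompt_tokens / num_text_tokens)
        init_token_ids = init_token_ids * num_reps
    # ... then trim any overhang (also handles the too-long case).
    return init_token_ids[:num_prompt_tokens]

# fit_to_length([10, 11, 12], 8) -> [10, 11, 12, 10, 11, 12, 10, 11]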
Example #2
0
def generate_fancy_data_labels(sequence_len, batch_size):
    global data_idx
    global inds
    global masks
    global MANUAL_SEED
    temps = []
    for i in range(batch_size):
        if inds is None or data_idx >= len(inds):
            # hack: re-seed so the RNG stays in sync across pipeline stages
            torch.manual_seed(MANUAL_SEED)
            inds = torch.randperm(effective_length, device="cuda")
            masks = (
                torch.rand(
                    len(inds) // batch_size + 1, batch_size, sequence_len, device="cuda"
                )
                >= MASK_PROB
            ).long()
            MANUAL_SEED += 1
            print("new epoch", len(inds))
            data_idx = 0
            print("my start", inds[0:5])
            print("masks_checksum:", torch.sum(masks))
        if EASY_MODE:
            data_idx_ = data_idx % EASY_MODE_SIZ
        else:
            data_idx_ = data_idx
        offset = inds[data_idx_]  # * SEQUENCE_LEN
        data_idx += 1

        curr = fancy_data[offset : offset + sequence_len].clone().detach()
        temps.append(curr)
    temp = torch.stack(temps, dim=0).cuda()
    mask = masks[data_idx // batch_size]
    mask_not = torch.logical_not(mask).long()
    data = mask * temp + mask_not * 124  # masked-out positions are filled with the hard-coded id 124
    label = temp
    if parallel_state.get_tensor_model_parallel_rank() == 0:
        data_dict = {"text": data, "label": label, "mask_not": mask_not}
    else:
        data_dict = None
    keys = ["text", "label", "mask_not"]
    dtype = torch.int64
    broadcasted_data = tensor_parallel.broadcast_data(keys, data_dict, dtype)
    return (
        broadcasted_data["text"].long(),
        broadcasted_data["label"].long(),
        broadcasted_data["mask_not"],
    )
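Every snippet in this file follows the same calling convention for tensor_parallel.broadcast_data: tensor-parallel rank 0 supplies a dict of int64 tensors, all other ranks pass None, and each rank gets back identical tensors under the same keys. Here is a hedged sketch of that pattern factored into a hypothetical helper (broadcast_int64_batch is our name; the exact import path depends on whether apex, megatron.core, or NeMo provides parallel_state and tensor_parallel, and distributed/model-parallel state must already be initialized).

import torch
from megatron.core import parallel_state, tensor_parallel  # adjust import to your setup

def broadcast_int64_batch(keys, batch_dict):
    """Broadcast int64 tensors from tensor-parallel rank 0 to all ranks."""
    if parallel_state.get_tensor_model_parallel_rank() == 0:
        data = {key: batch_dict[key] for key in keys}  # real tensors only on rank 0
    else:
        data = None                                    # other ranks only receive
    return tensor_parallel.broadcast_data(keys, data, torch.int64)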
Example #3
0
    def process_batch(self, batch):
        """Build the batch."""

        keys = ['text_enc', 'text_dec', 'labels', 'loss_mask', 'enc_mask', 'dec_mask']
        datatype = torch.int64
        data = batch
        data_b = tensor_parallel.broadcast_data(keys, data, datatype)

        # Unpack.
        tokens_enc = data_b['text_enc'].long()
        tokens_dec = data_b['text_dec'].long()
        labels = data_b['labels'].long()
        loss_mask = data_b['loss_mask'].float()

        enc_mask = data_b['enc_mask']
        dec_mask = data_b['dec_mask']

        return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask
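For orientation, the batch such a process_batch expects on tensor-parallel rank 0 is a dict keyed by the same strings and holding int64 tensors (loss_mask is only cast to float after the broadcast). The batch size, sequence lengths, and vocab size below are made up.

import torch

batch = {
    'text_enc':  torch.randint(0, 30000, (4, 512)),      # encoder input ids
    'text_dec':  torch.randint(0, 30000, (4, 128)),      # decoder input ids
    'labels':    torch.randint(0, 30000, (4, 128)),      # decoder targets
    'loss_mask': torch.ones(4, 128, dtype=torch.int64),  # float-cast after broadcast
    'enc_mask':  torch.ones(4, 512, dtype=torch.int64),
    'dec_mask':  torch.ones(4, 128, dtype=torch.int64),
}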
Example #4
0
    def process_batch(self, batch):
        """Build the batch."""
        # Items and their type.
        keys = [
            'text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'
        ]
        datatype = torch.int64

        data = batch
        data_b = tensor_parallel.broadcast_data(keys, data, datatype)

        # Unpack.
        tokens = data_b['text'].long()
        types = data_b['types'].long()
        sentence_order = data_b['is_random'].long()
        loss_mask = data_b['loss_mask'].float()
        lm_labels = data_b['labels'].long()
        padding_mask = data_b['padding_mask'].long()
        return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
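Downstream, lm_labels and loss_mask are usually combined into a masked token-level cross entropy. The sketch below is a generic version of that pattern, not the actual NeMo/Megatron loss code.

import torch
import torch.nn.functional as F

def masked_lm_loss(logits, lm_labels, loss_mask):
    """Cross entropy averaged only over positions where loss_mask == 1."""
    per_token_loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),  # [batch * seq, vocab]
        lm_labels.view(-1),                # [batch * seq]
        reduction='none',
    )
    loss_mask = loss_mask.view(-1).float()
    return torch.sum(per_token_loss * loss_mask) / loss_mask.sum()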
Example #5
0
    def init_prompt_from_text(self, taskname, init_token_ids, word_embeddings,
                              total_virtual_tokens):
        """Add new virtual prompt to be tuned.
           Intialize prompt weights from existing embeddings from specific vocab tokens.

        """
        # Trim or repeat tokens until num_text_tokens matches total_virtual_tokens
        num_text_tokens = len(init_token_ids)

        if num_text_tokens > total_virtual_tokens:
            init_token_ids = init_token_ids[:total_virtual_tokens]
        elif num_text_tokens < total_virtual_tokens:
            num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
            init_token_ids = init_token_ids * num_reps

        # Set dictionary item keys and datatypes for broadcasting
        keys = ['text']
        datatype = torch.int64

        # Broadcast int ids across gpus for tensor parallel
        init_token_ids = init_token_ids[:total_virtual_tokens]
        init_token_ids = {
            'text': torch.tensor(init_token_ids, dtype=torch.int64)
        }
        init_token_ids_b = tensor_parallel.broadcast_data(
            keys, init_token_ids, datatype)
        init_token_ids = init_token_ids_b['text'].long()

        # Use a copy of token embedding weights to initialize the prompt embeddings
        word_embedding_weights = word_embeddings(
            init_token_ids).detach().clone()

        self.prompt_table[taskname] = PromptEmbedding(
            init_from_prompt_text=True,
            hidden_size=self.hidden_size,
            total_virtual_tokens=total_virtual_tokens,
            word_embedding_weights=word_embedding_weights,
        )
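The lookup that seeds the prompt weights is an ordinary embedding forward pass. A self-contained sketch with made-up sizes shows the shape of the detached copy that gets handed to PromptEmbedding.

import torch
import torch.nn as nn

hidden_size = 16
vocab_size = 100
total_virtual_tokens = 8

word_embeddings = nn.Embedding(vocab_size, hidden_size)
init_token_ids = torch.randint(0, vocab_size, (total_virtual_tokens,))

# Detach + clone so tuning the prompt never updates the frozen vocab embeddings.
word_embedding_weights = word_embeddings(init_token_ids).detach().clone()
assert word_embedding_weights.shape == (total_virtual_tokens, hidden_size)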
Example #6
0
    def process_batch(self, batch):
        """Build the batch."""
        # Items and their type.
        keys = ['text']
        datatype = torch.int64

        data = batch
        data_b = tensor_parallel.broadcast_data(keys, data, datatype)

        # Unpack.
        tokens_ = data_b['text'].long()
        labels = tokens_[:, 1:].contiguous()
        tokens = tokens_[:, :-1].contiguous()

        # Get the masks and position ids.
        attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
            tokens,
            self.tokenizer.eos_id,
            self.cfg.data.get('reset_position_ids', False),
            self.cfg.data.get('reset_attention_mask', False),
            self.cfg.data.get('eod_mask_loss', False),
        )

        return tokens, labels, loss_mask, attention_mask, position_ids
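The core of the snippet is the one-token shift that turns a single 'text' tensor into inputs and next-token labels, shown here on a toy example. The position-id line reflects our reading of the simplest case (all reset flags off) and is an assumption, not the helper's actual code.

import torch

tokens_ = torch.tensor([[5, 6, 7, 8, 2]])  # 2 standing in for an EOS/EOD id
labels = tokens_[:, 1:].contiguous()       # [[6, 7, 8, 2]]
tokens = tokens_[:, :-1].contiguous()      # [[5, 6, 7, 8]]

# Assumed simplest case: position ids are just 0..seq_len-1 for each sample.
position_ids = torch.arange(tokens.size(1)).unsqueeze(0).expand_as(tokens)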