def init_prompt_from_text(self, prompt_tag, prompt_id, init_token_ids, embeddings):
    """Add a new soft prompt to be tuned.
    Initialize prompt weights from the existing embeddings of specific vocab tokens.
    """
    # Trim or repeat init_token_ids until num_text_tokens matches num_prompt_tokens
    num_text_tokens = len(init_token_ids)
    num_prompt_tokens = self.num_prompt_tokens

    if num_text_tokens > num_prompt_tokens:
        init_token_ids = init_token_ids[:num_prompt_tokens]
    elif num_text_tokens < num_prompt_tokens:
        num_reps = math.ceil(num_prompt_tokens / num_text_tokens)
        init_token_ids = init_token_ids * num_reps

    # Set dictionary item keys and datatypes for broadcasting
    keys = ['text']
    datatype = torch.int64

    # Broadcast int ids across gpus for tensor parallel
    init_token_ids = init_token_ids[:num_prompt_tokens]
    init_token_ids = {'text': torch.tensor(init_token_ids, dtype=torch.int64)}
    init_token_ids_b = tensor_parallel.broadcast_data(keys, init_token_ids, datatype)
    init_token_ids = init_token_ids_b['text'].long()
    init_position_ids = torch.arange(self.num_prompt_tokens, dtype=torch.long, device=init_token_ids.device)

    # Use a copy of the token embedding weights to initialize the prompt embeddings
    word_embeddings, position_embeddings = embeddings(init_token_ids, init_position_ids, separate_embeddings=True)
    word_embeddings = word_embeddings.detach().clone()
    position_embeddings = position_embeddings.detach().clone()

    prompt_embeddings = PromptEmbedding(
        init_from_prompt_text=True,
        hidden_size=self.hidden_size,
        num_prompt_tokens=self.num_prompt_tokens,
        word_embedding_weights=word_embeddings,
        position_embedding_weights=position_embeddings,
    )

    self.prompt_table[prompt_tag] = prompt_embeddings
    self.prompt_id_to_tag[prompt_id] = prompt_tag
def generate_fancy_data_labels(sequence_len, batch_size):
    global data_idx
    global inds
    global masks
    global MANUAL_SEED
    temps = []
    for i in range(batch_size):
        if inds is None or data_idx >= len(inds):
            # hack: re-seed explicitly, otherwise RNG use falls out of sync
            # because the pipeline stages differ
            torch.manual_seed(MANUAL_SEED)
            inds = torch.randperm(effective_length, device="cuda")
            masks = (
                torch.rand(
                    len(inds) // batch_size + 1, batch_size, sequence_len, device="cuda"
                )
                >= MASK_PROB
            ).long()
            MANUAL_SEED += 1
            print("new epoch", len(inds))
            data_idx = 0
            print("my start", inds[0:5])
            print("masks_checksum:", torch.sum(masks))
        if EASY_MODE:
            data_idx_ = data_idx % EASY_MODE_SIZ
        else:
            data_idx_ = data_idx
        offset = inds[data_idx_]  # * SEQUENCE_LEN
        data_idx += 1
        curr = fancy_data[offset : offset + sequence_len].clone().detach()
        temps.append(curr)
    temp = torch.stack(temps, dim=0).cuda()
    mask = masks[data_idx // batch_size]
    mask_not = torch.logical_not(mask).long()
    # Keep unmasked positions, fill masked-out positions with a fixed token id (124)
    data = mask * temp + mask_not * 124
    label = temp
    if parallel_state.get_tensor_model_parallel_rank() == 0:
        data_dict = {"text": data, "label": label, "mask_not": mask_not}
    else:
        data_dict = None
    keys = ["text", "label", "mask_not"]
    dtype = torch.int64
    broadcasted_data = tensor_parallel.broadcast_data(keys, data_dict, dtype)
    return (
        broadcasted_data["text"].long(),
        broadcasted_data["label"].long(),
        broadcasted_data["mask_not"],
    )
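# The helper above relies on module-level state (data_idx, inds, masks, MANUAL_SEED,
# fancy_data, effective_length, MASK_PROB, EASY_MODE, EASY_MODE_SIZ). A minimal,
# hedged sketch of how that state might be set up for a smoke test; the values below
# are illustrative placeholders, not the script's real constants.

SEQUENCE_LEN = 512
MASK_PROB = 0.15
EASY_MODE = False
EASY_MODE_SIZ = 32
MANUAL_SEED = 42

# Stand-in token stream; a real run would build fancy_data from an actual dataset.
fancy_data = torch.randint(0, 30000, (1_000_000,), device="cuda")
effective_length = fancy_data.size(0) - SEQUENCE_LEN  # last valid window start (assumption)

data_idx, inds, masks = 0, None, None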
def process_batch(self, batch): """Build the batch.""" keys = ['text_enc', 'text_dec', 'labels', 'loss_mask', 'enc_mask', 'dec_mask'] datatype = torch.int64 data = batch data_b = tensor_parallel.broadcast_data(keys, data, datatype) # Unpack. tokens_enc = data_b['text_enc'].long() tokens_dec = data_b['text_dec'].long() labels = data_b['labels'].long() loss_mask = data_b['loss_mask'].float() enc_mask = data_b['enc_mask'] dec_mask = data_b['dec_mask'] return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask
def process_batch(self, batch): """Build the batch.""" # Items and their type. keys = [ 'text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask' ] datatype = torch.int64 data = batch data_b = tensor_parallel.broadcast_data(keys, data, datatype) # Unpack. tokens = data_b['text'].long() types = data_b['types'].long() sentence_order = data_b['is_random'].long() loss_mask = data_b['loss_mask'].float() lm_labels = data_b['labels'].long() padding_mask = data_b['padding_mask'].long() return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
def init_prompt_from_text(self, taskname, init_token_ids, word_embeddings, total_virtual_tokens):
    """Add a new virtual prompt to be tuned.
    Initialize prompt weights from the existing embeddings of specific vocab tokens.
    """
    # Trim or repeat init_token_ids until num_text_tokens matches total_virtual_tokens
    num_text_tokens = len(init_token_ids)

    if num_text_tokens > total_virtual_tokens:
        init_token_ids = init_token_ids[:total_virtual_tokens]
    elif num_text_tokens < total_virtual_tokens:
        num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
        init_token_ids = init_token_ids * num_reps

    # Set dictionary item keys and datatypes for broadcasting
    keys = ['text']
    datatype = torch.int64

    # Broadcast int ids across gpus for tensor parallel
    init_token_ids = init_token_ids[:total_virtual_tokens]
    init_token_ids = {'text': torch.tensor(init_token_ids, dtype=torch.int64)}
    init_token_ids_b = tensor_parallel.broadcast_data(keys, init_token_ids, datatype)
    init_token_ids = init_token_ids_b['text'].long()

    # Use a copy of the token embedding weights to initialize the prompt embeddings
    word_embedding_weights = word_embeddings(init_token_ids).detach().clone()

    self.prompt_table[taskname] = PromptEmbedding(
        init_from_prompt_text=True,
        hidden_size=self.hidden_size,
        total_virtual_tokens=total_virtual_tokens,
        word_embedding_weights=word_embedding_weights,
    )
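# A hedged, self-contained sketch of the initialization performed above, outside the
# class and without the tensor-parallel broadcast or the PromptEmbedding wrapper.
# The embedding table, ids, and sizes are placeholders, not the real model's values:
# ids for the init text are tiled/trimmed to total_virtual_tokens and looked up in a
# frozen word-embedding table to seed the virtual-prompt weights.

import math
import torch

total_virtual_tokens = 10
word_embeddings = torch.nn.Embedding(30000, 768)   # stand-in frozen embedding table
init_token_ids = [101, 2023, 2003, 1037, 7099]     # made-up ids for the init text

num_reps = math.ceil(total_virtual_tokens / len(init_token_ids))
init_token_ids = (init_token_ids * num_reps)[:total_virtual_tokens]

ids = torch.tensor(init_token_ids, dtype=torch.long)
init_weights = word_embeddings(ids).detach().clone()  # [total_virtual_tokens, hidden_size]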
def process_batch(self, batch):
    """Build the batch."""

    # Items and their type.
    keys = ['text']
    datatype = torch.int64

    data = batch
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        self.tokenizer.eos_id,
        self.cfg.data.get('reset_position_ids', False),
        self.cfg.data.get('reset_attention_mask', False),
        self.cfg.data.get('eod_mask_loss', False),
    )

    return tokens, labels, loss_mask, attention_mask, position_ids
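# Hedged illustration of the shift above: for a 'text' tensor of length S + 1, the
# inputs are positions 0..S-1 and the labels are positions 1..S, i.e. standard
# next-token prediction. The id values below are made up.

import torch

text = torch.tensor([[50256, 11, 12, 13, 50256]])  # [batch, seq_len + 1]
tokens = text[:, :-1].contiguous()                  # [[50256, 11, 12, 13]]
labels = text[:, 1:].contiguous()                   # [[11, 12, 13, 50256]]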