def test_autocast_autodiff(self):
    def t(t0, t1):
        o = torch.mm(t0, t1)
        return o.relu()

    jit_t = torch.jit.script(t)
    t0 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()
    t1 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()

    # run optimization
    for i in range(5):
        with torch.autocast("cuda", torch.float16):
            jit_o = jit_t(t0, t1)
        jit_o.sum().backward()

    t0.grad = None
    t1.grad = None

    ref_t0 = t0.detach().requires_grad_()
    ref_t1 = t1.detach().requires_grad_()

    with torch.autocast("cuda", torch.float16):
        o = t(ref_t0, ref_t1)
        jit_o = jit_t(t0, t1)
    jit_o.sum().backward()
    o.sum().backward()

    self.assertEqual(o, jit_o)
    self.assertEqual(t0.grad, ref_t0.grad)
    self.assertEqual(t1.grad, ref_t1.grad)
    self.assertEqual(o.dtype, jit_o.dtype)
    self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
    self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)
def test_jit_executor_under_autocast(self):
    def t(cpu0, cpu1, cuda0, cuda1):
        cpu_o = torch.mm(cpu0, cpu1)
        cuda_o = torch.mm(cuda0, cuda1)
        return cpu_o, cuda_o

    jit_t = torch.jit.script(t)
    cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
    cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
    cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
    cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)

    with torch.autocast("cpu", torch.bfloat16):
        with torch.autocast("cuda", torch.float16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

    with torch.autocast("cpu", torch.bfloat16):
        self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

    with torch.autocast("cuda", torch.float16):
        self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

    # no cast op should be observed when executing outside autocast context
    self._test_autocast(t, None, cpu0, cpu1, cuda0, cuda1)
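# The helper `self._test_autocast` used above is not shown in this listing.
# Below is a hypothetical sketch of what such a checker might do, assuming it
# scripts the function, compares eager vs. scripted results, and looks for the
# expected cast op in the graph specialized for these inputs; the real helper
# in the test suite may differ.
import torch
from torch.testing import FileCheck


def check_autocast(func, cast_op, *args):
    # assumes `func` returns a tuple of tensors
    jit_func = torch.jit.script(func)
    eager_out = func(*args)
    jit_out = jit_func(*args)
    # eager and scripted execution should agree on dtypes and (approximately) values
    for e, j in zip(eager_out, jit_out):
        assert e.dtype == j.dtype
        torch.testing.assert_close(e, j)
    if cast_op is not None:
        # the specialized graph should contain the expected cast node
        FileCheck().check(cast_op).run(jit_func.graph_for(*args))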
def embed_input(self, queries):
    bz = queries.shape[0]
    queries_for_embedding = queries.clone()
    queries_for_embedding[(queries == self.pseudo_token_id)] = self.pad_token_id
    raw_embeds = self.embeddings(queries_for_embedding)

    dtype = self.model.model.language_model.encoder.layers[0].dtype
    if dtype == torch.float32:
        replace_embeds = self.prompt_encoder(enc_taskname=None)
    else:
        with torch.autocast(device_type="cuda", dtype=dtype):
            replace_embeds = self.prompt_encoder(enc_taskname=None)

    blocked_indices = queries == self.pseudo_token_id
    raw_embeds = raw_embeds.clone().type(dtype)
    # find the indices of the pseudo-tokens
    index = blocked_indices.nonzero().reshape((bz, -1, 2))[:, :, 1][:, :, None]
    _, seq, _ = index.shape
    _, _, emb = raw_embeds.shape
    index = index.expand(bz, seq, emb)
    _, replace_seq, replace_emb = replace_embeds.shape
    replace_embeds = replace_embeds.expand(bz, replace_seq, replace_emb)
    # scatter the pseudo-token embeddings into the raw embeddings
    raw_embeds.scatter_(1, index, replace_embeds)
    # slow version of the scatter logic above
    # for bidx in range(bz):
    #     position = blocked_indices[bidx].nonzero()[:, 0]
    #     for i in range(len(position)):
    #         raw_embeds[bidx, position[i], :] = replace_embeds[bidx, i, :]
    return raw_embeds
def embed_input(self, enc_input_id: Tensor, enc_taskname_id: Tensor):
    """
    Replace the virtual tokens in `enc_input_id` with embeddings computed by
    `prompt_encoder`. If `enc_taskname_id` is not None, the computed virtual
    token embeddings are dependent on it. The virtual token placeholders have
    the token id `self.pseudo_token_id`.
    params:
        enc_input_id: the input token ids
        enc_taskname_id: the NLP task tag token ids
    returns:
        the token embeddings for the LM model.
    """
    bz = enc_input_id.shape[0]
    queries_for_embedding = enc_input_id.clone()
    queries_for_embedding[(enc_input_id == self.pseudo_token_id)] = self.pad_token_id
    raw_embeds = self.embeddings(queries_for_embedding).clone()

    if self.cfg.prompt_encoder.task_dependent:
        enc_taskname = self.embeddings(enc_taskname_id)
    else:
        enc_taskname = None

    if self.float_type == torch.float32:
        replace_embeds = self.prompt_encoder(enc_taskname=enc_taskname)
    else:
        with torch.autocast(device_type="cuda", dtype=self.float_type):
            replace_embeds = self.prompt_encoder(enc_taskname=enc_taskname)

    blocked_indices = enc_input_id == self.pseudo_token_id
    raw_embeds = raw_embeds.clone().type(self.float_type)
    # find the indices of the pseudo-tokens
    index = blocked_indices.nonzero().reshape((bz, -1, 2))[:, :, 1][:, :, None]
    _, seq, _ = index.shape
    _, _, emb = raw_embeds.shape
    index = index.expand(bz, seq, emb)

    if enc_taskname is None:
        # taskname is None, so the encoder returns batch size 1
        # and needs to be expanded
        _, replace_seq, _ = replace_embeds.shape
        replace_embeds = replace_embeds.expand(bz, replace_seq, emb)

    # scatter the pseudo-token embeddings into the raw embeddings
    raw_embeds.scatter_(1, index, replace_embeds)
    # slow version of the scatter logic above
    # for bidx in range(bz):
    #     position = blocked_indices[bidx].nonzero()[:, 0]
    #     for i in range(len(position)):
    #         raw_embeds[bidx, position[i], :] = replace_embeds[bidx, i, :]
    return raw_embeds
def forward(
    self,
    input_ids,
    position_ids,
    attention_mask,
    taskname_ids,
    labels=None,
    inference=True,
    set_inference_key_value_memory=False,
    inference_max_sequence_len=None,
):
    """
    Special forward method for p-tuning/prompt-tuning pretrained GPT style models.
    Bypasses the vocab token preprocessing done in the MegatronGPT class.
    """
    # Get embeddings for text tokens and insert virtual token embeddings
    if inference:
        input_embeds = self.embed_input_inference(input_ids, taskname_ids)
    else:
        input_embeds = self.embed_input_train(input_ids, taskname_ids)

    position_embeddings = self.frozen_model.model.language_model.embedding.position_embeddings(position_ids)
    encoder_input = input_embeds + position_embeddings

    # Call forward on GPT model with preprocessed embeddings
    if self.float_type == torch.float32:
        output = self.frozen_model.model(
            input_ids=None,
            position_ids=None,
            encoder_input=encoder_input,
            attention_mask=attention_mask,
            labels=labels,
            set_inference_key_value_memory=set_inference_key_value_memory,
            inference_max_sequence_len=inference_max_sequence_len,
        )
    else:
        with torch.autocast(device_type="cuda", dtype=self.float_type):
            output = self.frozen_model.model(
                input_ids=None,
                position_ids=None,
                encoder_input=encoder_input,
                attention_mask=attention_mask,
                labels=labels,
                set_inference_key_value_memory=set_inference_key_value_memory,
                inference_max_sequence_len=inference_max_sequence_len,
            )

    return output
def forward(
    self,
    hidden_states,
    attention_mask,
    encoder_output=None,
    enc_dec_attn_mask=None,
    layer_past=None,
    get_key_value=False,
):
    if self.dtype == torch.float32:
        return super().forward(
            hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, layer_past, get_key_value
        )
    with torch.autocast(device_type="cuda", dtype=self.dtype):
        return super().forward(
            hidden_states, attention_mask, encoder_output, enc_dec_attn_mask, layer_past, get_key_value
        )
def get_loss(self, batch):
    tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_taskname = self.process_batch(batch)
    input_embeds = self.embed_input(tokens_enc, enc_taskname)
    encoder_position_ids = build_position_ids(tokens_enc)
    position_embeddings = self.position_embeddings(encoder_position_ids)
    encoder_input = input_embeds + position_embeddings

    if self.float_type == torch.float32:
        output = self.model.enc_dec_model(
            enc_input_ids=None,
            enc_attn_mask=enc_mask,
            dec_input_ids=tokens_dec,
            dec_attn_mask=dec_mask,
            token_type_ids=None,
            labels=labels,
            enc_hidden_states=None,
            output_enc_hidden_only=False,
            enc_input=encoder_input,
        )
    else:
        with torch.autocast(device_type="cuda", dtype=self.float_type):
            output = self.model.enc_dec_model(
                enc_input_ids=None,
                enc_attn_mask=enc_mask,
                dec_input_ids=tokens_dec,
                dec_attn_mask=dec_mask,
                token_type_ids=None,
                labels=labels,
                enc_hidden_states=None,
                output_enc_hidden_only=False,
                enc_input=encoder_input,
            )

    tokens_loss = output
    loss = self.model.loss_func(loss_mask, tokens_loss)
    self.log('train_loss', loss)
    return loss, tokens_enc, labels, enc_mask, encoder_input
def forward_eval(self, sentences):
    encoder_input, new_atten, label_position = self.get_encoder_input(sentences)
    batch_size, _, seq_len, _ = new_atten.shape

    # workaround to do auto-cast: get the LM dtype
    dtype = self.model.model.language_model.encoder.layers[0].dtype
    if dtype == torch.float32:
        output = self.model.model(
            None, None, encoder_input=encoder_input.to(self.device), attention_mask=new_atten.to(self.device)
        )
    else:
        with torch.autocast(device_type="cuda", dtype=dtype):
            output = self.model.model(
                None, None, encoder_input=encoder_input.to(self.device), attention_mask=new_atten.to(self.device)
            )

    logits = output
    _, returned_pred = self.get_prediction(batch_size, label_position.to(self.device), logits)
    return returned_pred
def get_loss(self, batch):
    enc_input = batch['enc_input']
    enc_taskname = batch['enc_taskname']
    labels = batch['labels']
    loss_mask = batch['loss_mask']
    enc_query = batch['enc_query']
    input_attn_mask = batch['input_attn_mask']

    input_attn_mask = input_attn_mask.unsqueeze(1) < 0.5
    input_embeds = self.embed_input(enc_input, enc_taskname)
    encoder_position_ids = build_position_ids(enc_input)
    position_embeddings = self.model.model.language_model.embedding.position_embeddings(encoder_position_ids)
    encoder_input = input_embeds + position_embeddings

    if self.float_type == torch.float32:
        output = self.model.model(
            None,
            None,
            encoder_input=encoder_input,
            attention_mask=input_attn_mask,
            labels=labels,
        )
    else:
        with torch.autocast(device_type="cuda", dtype=self.float_type):
            output = self.model.model(
                None,
                None,
                encoder_input=encoder_input,
                attention_mask=input_attn_mask,
                labels=labels,
            )

    output_tensor, encoder_hidden_states = output
    loss = self.loss_func(loss_mask, output_tensor)
    return loss
def update_model(engine, batch):
    model.train()
    images, targets = batch
    images = list(image.to(device) for image in images)
    targets = [
        {k: v.to(device) for k, v in t.items() if isinstance(v, torch.Tensor)}
        for t in targets
    ]

    with torch.autocast(device, enabled=True):
        loss_dict = model(images, targets)
        loss = sum(loss for loss in loss_dict.values())

    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    loss_items = {k: v.item() for k, v in loss_dict.items()}
    loss_items["loss_average"] = loss.item() / 4
    return loss_items
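# `update_model` above relies on names from its enclosing scope (model, device,
# optimizer, scaler). A minimal sketch of that surrounding setup follows,
# assuming an ignite-style engine and a torchvision detection model on CUDA;
# the specific model, optimizer settings, and engine wiring are assumptions,
# not part of the original code.
import torch
import torchvision
from ignite.engine import Engine

device = "cuda" if torch.cuda.is_available() else "cpu"
# any detection model that returns a dict of losses in training mode fits here
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)
# GradScaler pairs with torch.autocast so reduced-precision gradients do not underflow
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

trainer = Engine(update_model)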
def forward(self, sentences, labels):
    encoder_input, new_atten, label_position = self.get_encoder_input(sentences)
    batch_size, _, seq_len, _ = new_atten.shape
    labels_input, label_ids = self.get_label_input(labels, label_position, seq_len)

    # workaround to do auto-cast: get the LM dtype
    dtype = self.model.model.language_model.encoder.layers[0].dtype
    if dtype == torch.float32:
        output = self.model.model(
            None, None, encoder_input=encoder_input, attention_mask=new_atten, labels=labels_input
        )
    else:
        with torch.autocast(device_type="cuda", dtype=dtype):
            output = self.model.model(
                None, None, encoder_input=encoder_input, attention_mask=new_atten, labels=labels_input
            )

    loss, logits = output
    floss = (loss[(labels_input != SMALL_LOGITS)]).mean()
    _, returned_pred = self.get_prediction(batch_size, label_position, logits)
    returned_label = self.get_ground_truth_labels(batch_size, label_ids)
    return floss, returned_pred, returned_label
def decode(self, enc_query, enc_taskname, label_position, num_tokens_to_generate):
    with torch.no_grad():
        predicted_tokens_dec = enc_query
        label_start = label_position[:, 0].clone()
        for _ in range(num_tokens_to_generate):
            attn_mask = make_attention_mask_3d(predicted_tokens_dec, predicted_tokens_dec, self.pad_token_id)
            attn_mask = attn_mask * make_history_mask_3d(predicted_tokens_dec)
            attn_mask = attn_mask < 0.5
            attn_mask = attn_mask.unsqueeze(1)

            input_embeds = self.embed_input(predicted_tokens_dec, enc_taskname)
            encoder_position_ids = build_position_ids(predicted_tokens_dec)
            position_embeddings = self.model.model.language_model.embedding.position_embeddings(encoder_position_ids)
            encoder_input = input_embeds + position_embeddings

            if self.float_type == torch.float32:
                output = self.model.model(
                    None,
                    None,
                    encoder_input=encoder_input,
                    attention_mask=attn_mask,
                )
            else:
                with torch.autocast(device_type="cuda", dtype=self.float_type):
                    output = self.model.model(
                        None,
                        None,
                        encoder_input=encoder_input,
                        attention_mask=attn_mask,
                    )

            output_tensor = output
            output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
            # TODO: add logic to use the allowed labels if they are defined
            log_probs, token_ids = torch.max(nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
            new_pred = torch.full_like(token_ids[:, 0:1], self.pad_token_id)
            predicted_tokens_dec = torch.cat([predicted_tokens_dec, new_pred], 1)
            predicted = torch.gather(token_ids, 1, label_start.view(-1, 1))
            # scatter the predicted token id at the right position
            label_start += 1
            predicted_tokens_dec.scatter_(1, label_start.view(-1, 1), predicted)

    return predicted_tokens_dec, log_probs
def t(cpu0, cpu1, cuda0, cuda1):
    with torch.autocast("cpu", torch.bfloat16):
        with torch.autocast("cuda", torch.float16):
            cpu_o = torch.mm(cpu0, cpu1)
            cuda_o = torch.mm(cuda0, cuda1)
            return cpu_o, cuda_o
def t_autocast_cuda(x, y):
    # omitting the dtype argument is not currently supported
    with torch.autocast("cuda"):
        return torch.mm(x, y)
def t_autocast_cuda(x, y):
    with torch.autocast("cuda", dtype=torch.half):
        return torch.mm(x, y)
def t_autocast_cpu(x, y):
    with torch.autocast("cpu", dtype=torch.bfloat16):
        return torch.mm(x, y)
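# The small functions above carry the autocast context inside the scripted code,
# so no context manager is needed at the call site. A minimal usage sketch,
# assuming a CUDA device is available; shapes and variable names are arbitrary.
import torch

x = torch.randn(8, 8, device="cuda", dtype=torch.float32)
y = torch.randn(8, 8, device="cuda", dtype=torch.float32)
scripted_cuda = torch.jit.script(t_autocast_cuda)  # the dtype=torch.half variant
out = scripted_cuda(x, y)
print(out.dtype)  # expected torch.float16: mm runs in the reduced precision

cx = torch.randn(8, 8, dtype=torch.float32)
cy = torch.randn(8, 8, dtype=torch.float32)
scripted_cpu = torch.jit.script(t_autocast_cpu)
print(scripted_cpu(cx, cy).dtype)  # expected torch.bfloat16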
def ptune_inference(self, queries: List[Dict], batch_size: int = 1, decode_token_len: int = None) -> List[str]:
    """
    Get predictions for the queries.
    Args:
        queries: List of data samples without labels
        batch_size: batch size to use during inference
        decode_token_len: max number of tokens to generate during inference
    Returns:
        all_preds: model predictions
    """
    if decode_token_len is None:
        decode_token_len = self.decoder_seq_length

    # store predictions for all queries in a single list
    all_preds = []
    mode = self.training
    try:
        # Switch model to evaluation mode
        self.eval()
        logging_level = logging.get_verbosity()
        logging.set_verbosity(logging.WARNING)
        dataloader_cfg = {"batch_size": batch_size, "num_workers": 3, "pin_memory": False}
        infer_datalayer = self._setup_infer_dataloader(dataloader_cfg, queries, decode_token_len)

        for i, batch in enumerate(infer_datalayer):
            tokens_enc = batch['text_enc'].to(self.device)
            enc_taskname = batch['enc_taskname'].to(self.device)
            enc_mask = batch['enc_mask'].to(self.device)

            input_embeds = self.embed_input(tokens_enc, enc_taskname)
            encoder_position_ids = build_position_ids(tokens_enc)
            position_embeddings = self.position_embeddings(encoder_position_ids)
            encoder_input = input_embeds + position_embeddings
            # loss, tokens_enc, labels, enc_mask, encoder_input = self.get_loss(batch)

            if self.float_type == torch.float32:
                predicted_token_ids, _ = self.model.decode(
                    tokens_enc=tokens_enc,
                    enc_mask=enc_mask,
                    num_tokens_to_generate=decode_token_len,
                    enc_input=encoder_input,
                )
            else:
                with torch.autocast(device_type="cuda", dtype=self.float_type):
                    predicted_token_ids, _ = self.model.decode(
                        tokens_enc=tokens_enc,
                        enc_mask=enc_mask,
                        num_tokens_to_generate=decode_token_len,
                        enc_input=encoder_input,
                    )

            preds = predicted_token_ids.cpu().numpy().tolist()
            for i, pred in enumerate(preds):
                if self.tokenizer.eos_id in pred:
                    idx = pred.index(self.tokenizer.eos_id)
                    pred = pred[:idx]
                pred = [id for id in pred if id not in self.tokenizer.special_token_to_id.values()]
                pred = self.tokenizer.ids_to_text(pred)
                all_preds.append(pred)
    finally:
        # set mode back to its original value
        self.train(mode=mode)
        logging.set_verbosity(logging_level)

    return all_preds