Example #1
    def test_autocast_autodiff(self):
        def t(t0, t1):
            o = torch.mm(t0, t1)
            return o.relu()

        jit_t = torch.jit.script(t)
        t0 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()
        t1 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()

        # warm up: run a few iterations so the profiling executor optimizes the graph
        for i in range(5):
            with torch.autocast("cuda", torch.float16):
                jit_o = jit_t(t0, t1)
            jit_o.sum().backward()

        t0.grad = None
        t1.grad = None
        ref_t0 = t0.detach().requires_grad_()
        ref_t1 = t1.detach().requires_grad_()

        with torch.autocast("cuda", torch.float16):
            o = t(ref_t0, ref_t1)
            jit_o = jit_t(t0, t1)
        jit_o.sum().backward()
        o.sum().backward()
        self.assertEqual(o, jit_o)
        self.assertEqual(t0.grad, ref_t0.grad)
        self.assertEqual(t1.grad, ref_t1.grad)
        self.assertEqual(o.dtype, jit_o.dtype)
        self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
        self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)
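The test above exercises autodiff through a scripted function under autocast. Stripped of the JIT machinery, the pattern it checks is the usual one: run the forward pass inside the autocast region and call backward() outside it. A minimal sketch with illustrative tensors:

import torch

# forward under autocast, backward outside the region
if torch.cuda.is_available():
    a = torch.randn(5, 5, device="cuda", requires_grad=True)
    b = torch.randn(5, 5, device="cuda", requires_grad=True)
    with torch.autocast("cuda", torch.float16):
        out = torch.mm(a, b).relu()  # mm runs in float16 under autocast
    out.sum().backward()             # backward outside the autocast region
    print(out.dtype, a.grad.dtype)   # torch.float16, torch.float32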
Example #2
    def test_jit_executor_under_autocast(self):

        def t(cpu0, cpu1, cuda0, cuda1):
            cpu_o = torch.mm(cpu0, cpu1)
            cuda_o = torch.mm(cuda0, cuda1)
            return cpu_o, cuda_o

        jit_t = torch.jit.script(t)
        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)

        with torch.autocast("cpu", torch.bfloat16):
            with torch.autocast("cuda", torch.float16):
                self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cpu", torch.bfloat16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cuda", torch.float16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        # no cast op should be observed when executing outside autocast context
        self._test_autocast(t, None, cpu0, cpu1, cuda0, cuda1)
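The helper `_test_autocast` (defined elsewhere in the test class) evidently runs the scripted function and checks whether the named cast op appears in the executed graph. A hypothetical sketch of that kind of check, using the real torch.jit.last_executed_optimized_graph() API:

import torch

def f(x, y):
    return torch.mm(x, y)

jit_f = torch.jit.script(f)
if torch.cuda.is_available():
    x = torch.randn(5, 5, device="cuda")
    y = torch.randn(5, 5, device="cuda")
    with torch.autocast("cuda", torch.float16):
        jit_f(x, y)
        jit_f(x, y)  # second run lets the profiling executor optimize
    # the inserted casts should show up as aten::_autocast_to_reduced_precision
    print(torch.jit.last_executed_optimized_graph())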
Example #3
    def embed_input(self, queries):
        bz = queries.shape[0]
        queries_for_embedding = queries.clone()

        queries_for_embedding[(queries == self.pseudo_token_id)] = self.pad_token_id
        raw_embeds = self.embeddings(queries_for_embedding)
        dtype = self.model.model.language_model.encoder.layers[0].dtype
        if dtype == torch.float32:
            replace_embeds = self.prompt_encoder(enc_taskname=None)
        else:
            with torch.autocast(device_type="cuda", dtype=dtype):
                replace_embeds = self.prompt_encoder(enc_taskname=None)

        blocked_indices = queries == self.pseudo_token_id
        raw_embeds = raw_embeds.clone().type(dtype)
        # find the indices of the pseudo-tokens
        index = blocked_indices.nonzero().reshape((bz, -1, 2))[:, :, 1][:, :, None]

        _, seq, _ = index.shape
        _, _, emb = raw_embeds.shape
        index = index.expand(bz, seq, emb)

        _, replace_seq, replace_emb = replace_embeds.shape
        replace_embeds = replace_embeds.expand(bz, replace_seq, replace_emb)
        # scatter the pseudo-token embeddings into the raw embeddings
        raw_embeds.scatter_(1, index, replace_embeds)
        # slow (loop-based) version of the scatter logic above
        # for bidx in range(bz):
        #     position = blocked_indices[bidx].nonzero()[:, 0]
        #     for i in range(len(position)):
        #         raw_embeds[bidx, position[i], :] = replace_embeds[bidx, i, :]

        return raw_embeds
Example #4
    def embed_input(self, enc_input_id: Tensor, enc_taskname_id: Tensor):
        """
        This method will replace the virtual tokens in the enc_input_id with
        embeddings calculated from `prompt_encoder`. If the `enc_taskname_id` is
        not None, the computed virtual token embeddings are depenedent on it.
        The virtual token placeholders has the token_id `self.pseudo_token_id`.
        params:
            enc_input_id: the input token ids
            enc_taskname_id: the NLP task tag token ids
        returns:
            the token embedding for the LM model.
        """
        bz = enc_input_id.shape[0]
        queries_for_embedding = enc_input_id.clone()

        queries_for_embedding[(enc_input_id == self.pseudo_token_id)] = self.pad_token_id

        raw_embeds = self.embeddings(queries_for_embedding).clone()
        if self.cfg.prompt_encoder.task_dependent:
            enc_taskname = self.embeddings(enc_taskname_id)
        else:
            enc_taskname = None

        if self.float_type == torch.float32:
            replace_embeds = self.prompt_encoder(enc_taskname=enc_taskname)
        else:
            with torch.autocast(device_type="cuda", dtype=self.float_type):
                replace_embeds = self.prompt_encoder(enc_taskname=enc_taskname)

        blocked_indices = enc_input_id == self.pseudo_token_id
        raw_embeds = raw_embeds.clone().type(self.float_type)
        # find the indices of the pseudo-tokens
        index = blocked_indices.nonzero().reshape((bz, -1, 2))[:, :, 1][:, :, None]

        _, seq, _ = index.shape
        _, _, emb = raw_embeds.shape
        index = index.expand(bz, seq, emb)

        if enc_taskname is None:
            # when taskname is None the prompt encoder returns batch size 1,
            # so expand it to the full batch
            _, replace_seq, _ = replace_embeds.shape
            replace_embeds = replace_embeds.expand(bz, replace_seq, emb)

        # scatter the pseudo-token embeddings into the raw embeddings
        raw_embeds.scatter_(1, index, replace_embeds)
        # slow (loop-based) version of the scatter logic above
        # for bidx in range(bz):
        #     position = blocked_indices[bidx].nonzero()[:, 0]
        #     for i in range(len(position)):
        #         raw_embeds[bidx, position[i], :] = replace_embeds[bidx, i, :]

        return raw_embeds
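The scatter step is the core trick in both this example and the previous one: every pseudo-token position in the input has its row of `raw_embeds` overwritten by a row from `replace_embeds`. A self-contained sketch with hypothetical shapes and token ids:

import torch

bz, seq_len, emb = 2, 6, 4
num_virtual = 2        # pseudo tokens per sample
pseudo_token_id = 99

queries = torch.randint(1, 50, (bz, seq_len))
queries[:, :num_virtual] = pseudo_token_id          # virtual tokens up front

raw_embeds = torch.randn(bz, seq_len, emb)
replace_embeds = torch.randn(bz, num_virtual, emb)  # from the prompt encoder

blocked_indices = queries == pseudo_token_id
# column index of each pseudo token, shaped (bz, num_virtual, 1)
index = blocked_indices.nonzero().reshape(bz, -1, 2)[:, :, 1][:, :, None]
index = index.expand(bz, num_virtual, emb)
raw_embeds.scatter_(1, index, replace_embeds)

assert torch.equal(raw_embeds[:, :num_virtual], replace_embeds)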
Example #5
    def forward(
        self,
        input_ids,
        position_ids,
        attention_mask,
        taskname_ids,
        labels=None,
        inference=True,
        set_inference_key_value_memory=False,
        inference_max_sequence_len=None,
    ):
        """
        Special forward method for p-tuning/prompt-tuning pretrained
        GPT style models. Bypasses the vocab token preprocessing done
        in the MegatronGPT class.
        """
        # Get embeddings for text tokens and insert virtual token embeddings
        if inference:
            input_embeds = self.embed_input_inference(input_ids, taskname_ids)
        else:
            input_embeds = self.embed_input_train(input_ids, taskname_ids)

        position_embeddings = self.frozen_model.model.language_model.embedding.position_embeddings(
            position_ids)
        encoder_input = input_embeds + position_embeddings

        # Call forward on GPT model with preprocessed embeddings
        if self.float_type == torch.float32:
            output = self.frozen_model.model(
                input_ids=None,
                position_ids=None,
                encoder_input=encoder_input,
                attention_mask=attention_mask,
                labels=labels,
                set_inference_key_value_memory=set_inference_key_value_memory,
                inference_max_sequence_len=inference_max_sequence_len,
            )
        else:
            with torch.autocast(device_type="cuda", dtype=self.float_type):
                output = self.frozen_model.model(
                    input_ids=None,
                    position_ids=None,
                    encoder_input=encoder_input,
                    attention_mask=attention_mask,
                    labels=labels,
                    set_inference_key_value_memory=set_inference_key_value_memory,
                    inference_max_sequence_len=inference_max_sequence_len,
                )

        return output
Example #6
    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_output=None,
        enc_dec_attn_mask=None,
        layer_past=None,
        get_key_value=False,
    ):

        if self.dtype == torch.float32:
            return super().forward(hidden_states, attention_mask,
                                   encoder_output, enc_dec_attn_mask,
                                   layer_past, get_key_value)
        with torch.autocast(device_type="cuda", dtype=self.dtype):
            return super().forward(hidden_states, attention_mask,
                                   encoder_output, enc_dec_attn_mask,
                                   layer_past, get_key_value)
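This fp32-versus-autocast dispatch recurs in most of the NeMo examples here. A compact equivalent, sketched with a hypothetical helper, selects the context once instead of duplicating the call:

import torch
from contextlib import nullcontext

def maybe_autocast(dtype):
    # no-op context for float32, CUDA autocast otherwise
    if dtype == torch.float32:
        return nullcontext()
    return torch.autocast(device_type="cuda", dtype=dtype)

# usage, e.g. in a forward method like the one above:
#     with maybe_autocast(self.dtype):
#         return super().forward(hidden_states, attention_mask, ...)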
Example #7
    def get_loss(self, batch):
        tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_taskname = self.process_batch(
            batch)
        input_embeds = self.embed_input(tokens_enc, enc_taskname)

        encoder_position_ids = build_position_ids(tokens_enc)

        position_embeddings = self.position_embeddings(encoder_position_ids)

        encoder_input = input_embeds + position_embeddings

        if self.float_type == torch.float32:
            output = self.model.enc_dec_model(
                enc_input_ids=None,
                enc_attn_mask=enc_mask,
                dec_input_ids=tokens_dec,
                dec_attn_mask=dec_mask,
                token_type_ids=None,
                labels=labels,
                enc_hidden_states=None,
                output_enc_hidden_only=False,
                enc_input=encoder_input,
            )
        else:
            with torch.autocast(device_type="cuda", dtype=self.float_type):
                output = self.model.enc_dec_model(
                    enc_input_ids=None,
                    enc_attn_mask=enc_mask,
                    dec_input_ids=tokens_dec,
                    dec_attn_mask=dec_mask,
                    token_type_ids=None,
                    labels=labels,
                    enc_hidden_states=None,
                    output_enc_hidden_only=False,
                    enc_input=encoder_input,
                )

        tokens_loss = output

        loss = self.model.loss_func(loss_mask, tokens_loss)
        self.log('train_loss', loss)

        return loss, tokens_enc, labels, enc_mask, encoder_input
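Example #8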
    def forward_eval(self, sentences):
        encoder_input, new_atten, label_position = self.get_encoder_input(sentences)
        batch_size, _, seq_len, _ = new_atten.shape

        # workaround for autocast: read the LM dtype from the first encoder layer
        dtype = self.model.model.language_model.encoder.layers[0].dtype

        if dtype == torch.float32:
            output = self.model.model(
                None, None, encoder_input=encoder_input.to(self.device), attention_mask=new_atten.to(self.device)
            )
        else:
            with torch.autocast(device_type="cuda", dtype=dtype):
                output = self.model.model(
                    None, None, encoder_input=encoder_input.to(self.device), attention_mask=new_atten.to(self.device)
                )
        logits = output

        _, returned_pred = self.get_prediction(batch_size, label_position.to(self.device), logits)
        return returned_pred
Example #9
    def get_loss(self, batch):
        enc_input = batch['enc_input']
        enc_taskname = batch['enc_taskname']
        labels = batch['labels']
        loss_mask = batch['loss_mask']
        enc_query = batch['enc_query']
        input_attn_mask = batch['input_attn_mask']

        input_attn_mask = input_attn_mask.unsqueeze(1) < 0.5

        input_embeds = self.embed_input(enc_input, enc_taskname)

        encoder_position_ids = build_position_ids(enc_input)

        position_embeddings = self.model.model.language_model.embedding.position_embeddings(
            encoder_position_ids)

        encoder_input = input_embeds + position_embeddings

        if self.float_type == torch.float32:
            output = self.model.model(
                None,
                None,
                encoder_input=encoder_input,
                attention_mask=input_attn_mask,
                labels=labels,
            )
        else:
            with torch.autocast(device_type="cuda", dtype=self.float_type):
                output = self.model.model(
                    None,
                    None,
                    encoder_input=encoder_input,
                    attention_mask=input_attn_mask,
                    labels=labels,
                )
        output_tensor, encoder_hidden_states = output
        loss = self.loss_func(loss_mask, output_tensor)
        return loss
Example #10
    def update_model(engine, batch):
        model.train()
        images, targets = batch
        images = [image.to(device) for image in images]
        targets = [{
            k: v.to(device)
            for k, v in t.items() if isinstance(v, torch.Tensor)
        } for t in targets]

        with torch.autocast(device, enabled=True):
            loss_dict = model(images, targets)
            loss = sum(loss_dict.values())

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_items = {k: v.item() for k, v in loss_dict.items()}
        loss_items["loss_average"] = loss.item() / 4

        return loss_items
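Unlike the NeMo snippets, this training step pairs autocast with gradient scaling. The skeleton of that pattern, as a minimal sketch with a hypothetical model and optimizer:

import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(8, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    x = torch.randn(4, 8, device="cuda")
    with torch.autocast("cuda", torch.float16):
        loss = model(x).sum()

    optimizer.zero_grad()
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales grads; skips the step on inf/nan
    scaler.update()                # adjust the scale factor for the next step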
Example #11
    def forward(self, sentences, labels):
        encoder_input, new_atten, label_position = self.get_encoder_input(sentences)
        batch_size, _, seq_len, _ = new_atten.shape
        labels_input, label_ids = self.get_label_input(labels, label_position, seq_len)
        # workaround for autocast: read the LM dtype from the first encoder layer
        dtype = self.model.model.language_model.encoder.layers[0].dtype

        if dtype == torch.float32:
            output = self.model.model(
                None, None, encoder_input=encoder_input, attention_mask=new_atten, labels=labels_input
            )
        else:
            with torch.autocast(device_type="cuda", dtype=dtype):
                output = self.model.model(
                    None, None, encoder_input=encoder_input, attention_mask=new_atten, labels=labels_input
                )
        loss, logits = output
        floss = (loss[(labels_input != SMALL_LOGITS)]).mean()

        _, returned_pred = self.get_prediction(batch_size, label_position, logits)
        returned_label = self.get_ground_truth_labels(batch_size, label_ids)
        return floss, returned_pred, returned_label
Example #12
    def decode(self, enc_query, enc_taskname, label_position,
               num_tokens_to_generate):
        with torch.no_grad():
            predicted_tokens_dec = enc_query

            label_start = label_position[:, 0].clone()

            for _ in range(num_tokens_to_generate):
                attn_mask = make_attention_mask_3d(predicted_tokens_dec,
                                                   predicted_tokens_dec,
                                                   self.pad_token_id)
                attn_mask = attn_mask * make_history_mask_3d(
                    predicted_tokens_dec)

                attn_mask = attn_mask < 0.5

                attn_mask = attn_mask.unsqueeze(1)

                input_embeds = self.embed_input(predicted_tokens_dec,
                                                enc_taskname)

                encoder_position_ids = build_position_ids(predicted_tokens_dec)
                position_embeddings = self.model.model.language_model.embedding.position_embeddings(
                    encoder_position_ids)

                encoder_input = input_embeds + position_embeddings

                if self.float_type == torch.float32:
                    output = self.model.model(
                        None,
                        None,
                        encoder_input=encoder_input,
                        attention_mask=attn_mask,
                    )
                else:
                    with torch.autocast(device_type="cuda",
                                        dtype=self.float_type):
                        output = self.model.model(
                            None,
                            None,
                            encoder_input=encoder_input,
                            attention_mask=attn_mask,
                        )
                output_tensor = output

                output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(
                    output_tensor)

                # TODO, add logic to use the allowed labels if it is defined
                log_probs, token_ids = torch.max(
                    nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)

                new_pred = torch.full_like(token_ids[:, 0:1],
                                           self.pad_token_id)
                predicted_tokens_dec = torch.cat(
                    [predicted_tokens_dec, new_pred], 1)

                predicted = torch.gather(token_ids, 1, label_start.view(-1, 1))

                # need to scatter the token id at the right position
                label_start += 1
                predicted_tokens_dec.scatter_(1, label_start.view(-1, 1),
                                              predicted)

        return predicted_tokens_dec, log_probs
Example #13
 def t(cpu0, cpu1, cuda0, cuda1):
     with torch.autocast("cpu", torch.bfloat16):
         with torch.autocast("cuda", torch.float16):
             cpu_o = torch.mm(cpu0, cpu1)
             cuda_o = torch.mm(cuda0, cuda1)
             return cpu_o, cuda_o
Example #14
 def t_autocast_cuda(x, y):
     # omitting the dtype is not currently supported
     with torch.autocast("cuda"):
         return torch.mm(x, y)
Example #15
 def t_autocast_cuda(x, y):
     with torch.autocast("cuda", dtype=torch.half):
         return torch.mm(x, y)
Example #16
 def t_autocast_cpu(x, y):
     with torch.autocast("cpu", dtype=torch.bfloat16):
         return torch.mm(x, y)
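A quick way to see what the explicit-dtype variants do: inside the region, eligible ops cast their inputs down and return the reduced dtype. Minimal sketch for the CPU case:

import torch

x = torch.randn(4, 4)
y = torch.randn(4, 4)
with torch.autocast("cpu", dtype=torch.bfloat16):
    out = torch.mm(x, y)
print(out.dtype)  # torch.bfloat16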
Example #17
    def ptune_inference(self,
                        queries: List[Dict],
                        batch_size: int = 1,
                        decode_token_len: int = None) -> List[str]:
        """
        Get predictions for the queries
        Args:
            queries: List of data samples without labels
            batch_size: batch size to use during inference
            decode_token_len: max number of tokens to generate during inference
        Returns:
            all_preds: model predictions
        """
        if decode_token_len is None:
            decode_token_len = self.decoder_seq_length
        # store predictions for all queries in a single list
        all_preds = []
        mode = self.training
        try:
            # Switch model to evaluation mode
            self.eval()
            logging_level = logging.get_verbosity()
            logging.set_verbosity(logging.WARNING)
            dataloader_cfg = {
                "batch_size": batch_size,
                "num_workers": 3,
                "pin_memory": False
            }
            infer_datalayer = self._setup_infer_dataloader(
                dataloader_cfg, queries, decode_token_len)
            for i, batch in enumerate(infer_datalayer):
                tokens_enc = batch['text_enc'].to(self.device)
                enc_taskname = batch['enc_taskname'].to(self.device)
                enc_mask = batch['enc_mask'].to(self.device)

                input_embeds = self.embed_input(tokens_enc, enc_taskname)

                encoder_position_ids = build_position_ids(tokens_enc)

                position_embeddings = self.position_embeddings(
                    encoder_position_ids)

                encoder_input = input_embeds + position_embeddings

                # loss, tokens_enc, labels, enc_mask, encoder_input = self.get_loss(batch)
                if self.float_type == torch.float32:
                    predicted_token_ids, _ = self.model.decode(
                        tokens_enc=tokens_enc,
                        enc_mask=enc_mask,
                        num_tokens_to_generate=decode_token_len,
                        enc_input=encoder_input,
                    )
                else:
                    with torch.autocast(device_type="cuda",
                                        dtype=self.float_type):
                        predicted_token_ids, _ = self.model.decode(
                            tokens_enc=tokens_enc,
                            enc_mask=enc_mask,
                            num_tokens_to_generate=decode_token_len,
                            enc_input=encoder_input,
                        )

                preds = predicted_token_ids.cpu().numpy().tolist()
                for pred in preds:
                    if self.tokenizer.eos_id in pred:
                        idx = pred.index(self.tokenizer.eos_id)
                        pred = pred[:idx]
                    pred = [
                        id for id in pred if id not in
                        self.tokenizer.special_token_to_id.values()
                    ]
                    pred = self.tokenizer.ids_to_text(pred)
                    all_preds.append(pred)
        finally:
            # set mode back to its original value
            self.train(mode=mode)
            logging.set_verbosity(logging_level)
        return all_preds