def trim_seq2seq_batch(batch, pad_token_id):
    target_ids = trim_batch(batch["target_ids"], pad_token_id)
    source_ids, source_mask = trim_batch(
        batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"]
    )
    return source_ids, source_mask, target_ids

def trim_seq2seq_batch(batch, pad_token_id, test=False):
    # Remove columns that are populated exclusively by pad_token_id.
    # This ensures that each batch is padded only up to its own max sequence length.
    # https://github.com/huggingface/transformers/blob/1e51bb717c04ca4b01a05a7a548e6b550be38628/src/transformers/tokenization_utils.py
    source_ids, source_mask = trim_batch(
        batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"]
    )
    if test:
        return source_ids, source_mask, None
    y = trim_batch(batch["target_ids"], pad_token_id)
    return source_ids, source_mask, y

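# Both helpers above depend on trim_batch, which is not shown in this file.
# What follows is a minimal sketch based on the trim_batch utility in the
# transformers source linked above; treat it as an illustration, not a
# verbatim copy of the upstream implementation.
import torch


def trim_batch(input_ids, pad_token_id, attention_mask=None):
    """Drop columns of input_ids that contain only pad_token_id."""
    # A column is kept if any row in it holds a non-pad token
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    return input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]
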
def collate_fn(self, batch):
    input_ids = torch.stack([x["source_ids"] for x in batch])
    masks = torch.stack([x["source_mask"] for x in batch])
    target_ids = torch.stack([x["target_ids"] for x in batch])
    # Precomputed RoBERTa embeddings are batched alongside the token ids
    roberta_embeddings = torch.stack([x["roberta"] for x in batch])
    pad_token_id = self.tokenizer.pad_token_id
    y = trim_batch(target_ids, pad_token_id)
    source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
    return {
        "source_ids": source_ids,
        "source_mask": source_mask,
        "target_ids": y,
        "roberta_embeddings": roberta_embeddings,
    }

def collate_fn(self, batch):
    input_ids = torch.stack([x["source_ids"] for x in batch])
    masks = torch.stack([x["source_mask"] for x in batch])
    target_ids = torch.stack([x["target_ids"] for x in batch])
    pad_token_id = self.tokenizer.pad_token_id
    y = trim_batch(target_ids, pad_token_id)
    source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
    return {
        "source_ids": source_ids,
        "source_mask": source_mask,
        "target_ids": y,
    }

def collate_fn(self, batch): """ The tensors are stacked together as they are yielded. Collate function is applied to the output of a DataLoader as it is yielded. """ input_ids = torch.stack([x["source_ids"] for x in batch]) # BS x SL masks = torch.stack([x["source_mask"] for x in batch]) # BS x SL pad_token_id = self.tokenizer.pad_token_id source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) if self.type_path == "test": return {"source_ids": source_ids, "source_mask": source_mask} target_ids = torch.stack([x["target_ids"] for x in batch]) # BS x SL # Remove columns that are purely padding y = trim_batch(target_ids, pad_token_id) # Return dictionary containing tensors return { "source_ids": source_ids, "source_mask": source_mask, "target_ids": y }