def __getitem__(self, index):

    doc_index = index

    input_ids = set_tensor_device(self.input_ids[doc_index], device=self.device)
    attention_mask = set_tensor_device(self.attention_mask[doc_index], device=self.device)

    if self.y is None:
        return (input_ids, attention_mask)

    else:
        doc_labels, sent_labels = self.y[index]

        doc_tensors = tensorfy_doc_labels_multi( \
                            doc_labels = doc_labels,
                            device = self.device)

        sent_tensors = tensorfy_sent_labels_multi( \
                            sent_labels = sent_labels,
                            batch_size = self.max_sent_count,
                            device = self.device)

        return (input_ids, attention_mask, doc_tensors, sent_tensors)
def span_indices_tensor(self, device=None):
    # (span_count, 2) tensor of span start/end indices
    tensor = torch.LongTensor(self.span_indices)
    tensor = set_tensor_device(tensor, device)
    return tensor
def document_tensor(self, labels, device=None):
    tensor_dict = OrderedDict([(type_, label_tensor(label)) \
                                for type_, label in labels.items()])
    for k, v in tensor_dict.items():
        tensor_dict[k] = set_tensor_device(v, device)
    return tensor_dict
def tensorfy_doc_labels_multi(doc_labels, device=None):
    # map each document-level label to a scalar (0-dim) LongTensor
    tensor_dict = OrderedDict()
    for doc_label, v in doc_labels.items():
        v = torch.LongTensor([v]).squeeze()
        v = set_tensor_device(v, device=device)
        tensor_dict[doc_label] = v
    return tensor_dict
def relation_tensor(self, relations, batch=False, device=None):
    if batch:
        tensor_dict = relation_tensor_batch(relations, self.span_count, self.batch_size)
    else:
        tensor_dict = relation_tensor_seq(relations, self.span_count)
    for k, v in tensor_dict.items():
        tensor_dict[k] = set_tensor_device(v, device)
    return tensor_dict
def entity_tensor(self, entities, batch=False, device=None):
    if batch:
        tensor_dict = entity_tensor_batch(entities, self.span_count, self.batch_size)
    else:
        tensor_dict = entity_tensor_seq(entities, self.span_count)
    for k, v in tensor_dict.items():
        tensor_dict[k] = set_tensor_device(v, device)
    return tensor_dict
def tensorfy_sent_labels_multi(sent_labels, batch_size, device=None):
    # pad each sentence-level label sequence to batch_size (the maximum sentence count)
    tensor_dict = OrderedDict()
    for doc_label, subtype_combos in sent_labels.items():
        tensor_dict[doc_label] = OrderedDict()
        for combo, v in subtype_combos.items():
            v = pad1D(torch.LongTensor(v), batch_size)
            v = set_tensor_device(v, device=device)
            tensor_dict[doc_label][combo] = v
    return tensor_dict
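# Hedged usage sketch (illustration only, not part of the original pipeline).
# It shows the nested label structures that tensorfy_doc_labels_multi and
# tensorfy_sent_labels_multi above appear to expect. The label names and values
# are made up, and it assumes the module's existing imports (OrderedDict, torch)
# and that set_tensor_device / pad1D accept device=None and pad to batch_size.
def _example_tensorfy_labels():

    # one integer label per document-level field
    doc_labels = OrderedDict([("Trigger", 2), ("Status", 1)])

    # per field, one integer label per sentence, keyed by subtype combination
    sent_labels = OrderedDict([
        ("Trigger", OrderedDict([("default", [0, 2, 1])])),
        ("Status",  OrderedDict([("default", [1, 0, 0])])),
    ])

    doc_tensors = tensorfy_doc_labels_multi(doc_labels=doc_labels, device=None)
    sent_tensors = tensorfy_sent_labels_multi(sent_labels=sent_labels,
                                              batch_size=5, device=None)

    # doc_tensors["Trigger"] is a scalar LongTensor;
    # sent_tensors["Trigger"]["default"] is padded to length 5
    return doc_tensors, sent_tensors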
def span_mask_tensor(self, seq_length, device=None):
    if torch.is_tensor(seq_length):
        seq_length = seq_length.tolist()
    # accept either a single sequence length or a list of lengths
    if isinstance(seq_length, (int, float)):
        tensor = torch.LongTensor(self.span_mask[seq_length])
    else:
        tensor = torch.LongTensor([self.span_mask[n] for n in seq_length])
    tensor = set_tensor_device(tensor, device)
    return tensor
def __getitem__(self, index):

    device = self.device

    doc_index, sent_index = self.index2doc_sent[index]

    seq_length = self.seq_length[doc_index][sent_index]
    seq_length = set_tensor_device(seq_length, device)

    # (sentence length, embedding dimension)
    seq_tensor = self.X[doc_index][sent_index]
    seq_tensor = set_tensor_device(seq_tensor, device)

    # (sentence length)
    seq_mask = self.mask[doc_index][sent_index]
    seq_mask = set_tensor_device(seq_mask, device)

    span_indices = self.span_mapper.span_indices_tensor(device=device)
    span_mask = self.span_mapper.span_mask_tensor(seq_length, device=device)

    if self.y is None:
        return (index, seq_tensor, seq_mask, span_indices, span_mask)

    else:
        y_ = self.y[doc_index][sent_index]

        y = OrderedDict()
        y['span_indices'] = span_indices
        y['span_mask'] = span_mask
        #y['seq_mask'] = seq_mask
        y["span_labels"] = self.span_mapper.entity_tensor(y_["span_labels"], batch=False, device=device)
        y["role_labels"] = self.span_mapper.relation_tensor(y_["role_labels"], batch=False, device=device)

        return (index, seq_tensor, seq_mask, span_indices, span_mask, y)
def encode_document(encoded_dict, model, \
    word_pieces_keep = None,
    device = None,
    detach = True,
    move_to_cpu = True,
    max_length = None,
    verbose = False,
    batch_size = 100):

    input_ids = encoded_dict[INPUT_IDS]
    mask = encoded_dict[ATTENTION_MASK]

    # encode the document in batches of sentences
    batches = ceil(len(mask)/batch_size)
    x_batches = []
    for i in range(batches):
        start = i*batch_size
        end = (i+1)*batch_size

        input_ids_batch = set_tensor_device(input_ids[start:end], device)
        mask_batch = set_tensor_device(mask[start:end], device)

        x = model( \
                input_ids = input_ids_batch,
                token_type_ids = None,
                attention_mask = mask_batch)[0]

        if move_to_cpu:
            x = x.cpu()
        if detach:
            x = x.detach()

        x_batches.append(x)

    x = torch.cat(x_batches, dim=0)
    assert len(x) == len(mask)

    # keep only the word pieces of interest
    if word_pieces_keep is not None:
        assert len(word_pieces_keep) == len(x)

        x_temp = torch.zeros_like(x)
        mask_temp = torch.zeros_like(mask)
        for i, wp_keep in enumerate(word_pieces_keep):
            for j, target in enumerate(wp_keep):
                # a (start, end) pair is mean-pooled; a single index is copied
                if isinstance(target, (list, tuple)):
                    a, b = tuple(target)
                    x_temp[i, j, :] = x[i, a:b, :].mean(dim=0)
                    mask_temp[i, j] = mask[i, a]
                else:
                    x_temp[i, j, :] = x[i, target, :]
                    mask_temp[i, j] = mask[i, target]
        x = x_temp
        mask = mask_temp

    # truncate to max_length
    if max_length is not None:
        if word_pieces_keep is not None:
            assert x[:, max_length:, :].sum().tolist() == 0
            assert mask[:, max_length:].sum().tolist() == 0
        x = x[:, :max_length, :]
        mask = mask[:, :max_length]

    return (x, mask)
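# Hedged usage sketch (illustration only). It assumes the module constants
# INPUT_IDS and ATTENTION_MASK equal "input_ids" and "attention_mask", that
# set_tensor_device accepts device=None, and that a document is tokenized as
# one sentence per row. The model name and sentences are placeholders.
def _example_encode_document():

    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased")
    model.eval()

    sentences = ["The patient denies tobacco use.",
                 "He drinks socially."]
    encoded_dict = tokenizer(sentences, padding=True, return_tensors="pt")

    # x: (sentence count, sequence length, hidden size); mask matches x
    x, mask = encode_document(encoded_dict, model,
                              device=None,      # CPU only in this sketch
                              max_length=None,
                              batch_size=100)
    return x, mask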
def __getitem__(self, index):

    doc_index = index

    #seq_length = pad1D(self.seq_length[index], self.max_sent_count)
    #seq_length = set_tensor_device(seq_length, self.device)
    #seq_length = self.seq_length[index]

    # (sentence count, sentence length, embedding dimension)
    seq_tensor = pad3D(self.X[index], self.max_sent_count)
    seq_tensor = set_tensor_device(seq_tensor, self.device)

    # (sentence count, sentence length)
    seq_mask = pad2D(self.mask[index], self.max_sent_count)
    seq_mask = set_tensor_device(seq_mask, self.device)

    # (span_count, 2)
    #span_indices = self.span_mapper.span_indices_tensor(device=self.device)

    # (sentence_count, span_count, 2)
    #span_indices = span_indices.repeat(self.max_sent_count, 1, 1)
    #span_mask = self.span_mapper.span_mask_tensor(seq_length, device=self.device)

    if self.y is None:
        #return (doc_index, seq_tensor, seq_mask, span_indices, span_mask)
        #return (doc_index, seq_tensor, seq_mask, None, None)
        return (doc_index, seq_tensor, seq_mask)

    else:
        y_ = self.y[index]

        y = OrderedDict()
        #y['seq_length'] = seq_length
        #y['span_indices'] = span_indices
        #y['span_mask'] = span_mask
        #y['seq_mask'] = seq_mask
        y["doc_labels"] = tensorfy_doc_labels_multi( \
                                doc_labels = y_["doc_labels"],
                                device = self.device)
        y["sent_labels"] = tensorfy_sent_labels_multi( \
                                sent_labels = y_["sent_labels"],
                                batch_size = self.max_sent_count,
                                device = self.device)
        #y["span_labels"] = self.span_mapper.entity_tensor(y_["span_labels"], batch=True, device=self.device)
        #y["role_labels"] = self.span_mapper.relation_tensor(y_["role_labels"], batch=True, device=self.device)

        #return (doc_index, seq_tensor, seq_mask, span_indices, span_mask, y)
        #return (doc_index, seq_tensor, seq_mask, None, None, y)
        return (doc_index, seq_tensor, seq_mask, y)
def encode_documents(input_ids, mask, \
    pretrained=PRETRAINED,
    device=None,
    train=False):

    logging.info("Embedding using AutoModel")

    model = AutoModel.from_pretrained(pretrained)

    if train:
        model.train()
    else:
        model.eval()

    set_model_device(model, device)

    X = []
    masks = []
    pbar = tqdm(total=len(input_ids))

    assert len(input_ids) == len(mask)
    for i, (ids, msk) in enumerate(zip(input_ids, mask)):

        ids = set_tensor_device(ids, device)
        msk = set_tensor_device(msk, device)

        x = model( \
                ids,
                token_type_ids=None,
                attention_mask=msk)[0]
        x = x.cpu().detach()

        X.append(x)

        if i == 1:
            logging.info("Encode documents")
            #logging.info('IDs: {}\n{}'.format(ids.shape, ids))
            logging.info('IDs: {}'.format(ids.shape))
            #logging.info('Mask: {}\n{}'.format(msk.shape, msk))
            logging.info('Mask: {}'.format(msk.shape))
            #logging.info('X: {}\n{}'.format(x.shape, x))
            logging.info('X: {}'.format(x.shape))
            logging.info('')

        pbar.update()
    pbar.close()

    logging.info("")
    logging.info('Document count: {}'.format(len(X)))
    logging.info("")

    return X
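# Hedged usage sketch (illustration only). It assumes each document is passed
# as one (input_ids, attention_mask) tensor pair covering its sentences, and
# that set_model_device / set_tensor_device accept device=None. The model name,
# documents, and _example_* naming are placeholders, not part of the pipeline.
def _example_encode_documents():

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    docs = [["No history of drug use.", "Works as a teacher."],
            ["Lives alone.", "Quit smoking in 2010."]]

    input_ids, mask = [], []
    for doc in docs:
        enc = tokenizer(doc, padding=True, return_tensors="pt")
        input_ids.append(enc["input_ids"])
        mask.append(enc["attention_mask"])

    # X[i]: (sentence count, sequence length, hidden size) for document i
    X = encode_documents(input_ids, mask,
                         pretrained="bert-base-uncased",
                         device=None,
                         train=False)
    return X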