def main():
    ####################################################################
    ## Data
    ####################################################################
    all_datasets = []
    for dataroot in args.dataroot:
        curr_dataset = BinaryDataset(root_dir=dataroot,
                                     binary_format='elf',
                                     targets='start',
                                     mode='random-chunks',
                                     chunk_length=args.sequence_len)
        all_datasets.append(curr_dataset)

    # ConcatDataset requires each dataset to implement __len__().
    dataset = torch.utils.data.ConcatDataset(all_datasets)
    print("Dataset len() = {0}".format(len(dataset)))
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True)

    ####################################################################
    ## Model
    ####################################################################
    config = BertConfig(
        vocab_size=256,
        hidden_size=args.hidden_size,
        num_hidden_layers=args.hidden_layers,
        num_attention_heads=args.num_attn_heads,
        intermediate_size=args.hidden_size * 4,  # BERT originally uses 4x hidden size here, so copy that.
        hidden_act='gelu',
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=args.sequence_len,  # Maximum sequence length
        type_vocab_size=1,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        gradient_checkpointing=False)

    model = BertForTokenClassification(config=config).cuda()
    # model = torch.nn.DataParallel(model, dim=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    lossfn = torch.nn.CrossEntropyLoss()

    print("Beginning training")
    for epoch in range(args.epochs):
        train_loss, train_acc = train(model, lossfn, optimizer, dataloader, epoch)
        print(f"Train Loss: {train_loss} | Train Acc: {train_acc}")
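# The train() helper called in main() is not defined in this file. The sketch
# below is an illustrative assumption, not the original implementation: it
# matches the call signature train(model, lossfn, optimizer, dataloader, epoch)
# and assumes the dataloader yields (chunks, labels) batches of byte ids and
# per-byte targets; it runs one epoch and returns mean loss and token accuracy.
def train(model, lossfn, optimizer, dataloader, epoch):
    model.train()
    total_loss, total_correct, total_tokens = 0.0, 0, 0
    for chunks, labels in dataloader:
        chunks, labels = chunks.cuda(), labels.cuda()
        optimizer.zero_grad()
        logits = model(input_ids=chunks)[0]  # (batch, seq_len, num_labels)
        loss = lossfn(logits.view(-1, logits.size(-1)), labels.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_correct += (logits.argmax(dim=-1) == labels).sum().item()
        total_tokens += labels.numel()
    return total_loss / len(dataloader), total_correct / total_tokens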
    '/home/brian/Downloads/all_samples_6-mer_train.txt')

seq_ids, masks, labels = tokenize_and_pad_samples(genes, labels)
print(seq_ids[0])
print(len(seq_ids))
print("Finished making data")

batch_size = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BertForTokenClassification(
    BertConfig.from_json_file(
        '/home/brian/attentive_splice/bert_configuration_all_hex.json'))
model.resize_token_embeddings(4099)
model.to(device)

optimizer = Adam(model.parameters(), lr=1e-3)  # lr=3e-5
# Weight the positive (splice-site) class heavily to counter class imbalance.
class_weights = torch.tensor(np.array([1.0, 165.0])).float().cuda()
loss = CrossEntropyLoss(weight=class_weights)

last_i = 0


def load_model_from_saved():
    global last_i
    with open('/home/brian/bert_last_i.txt', 'r') as last_i_file:
        i = last_i_file.read()
    last_i = int(i)
    model.load_state_dict(torch.load("/home/brian/bert_splice_weights.pt"))


def save_weights():
    print("Saving weights")
    path = "/home/brian/bert_weights_6mer.pt"
class TorchBertSequenceTagger(TorchModel):
    """BERT-based model on PyTorch for text tagging.

    It predicts a label for every token (not subtoken) in the text.
    You can use it for sequence labeling tasks, such as morphological tagging
    or named entity recognition.

    Args:
        n_tags: number of distinct tags
        pretrained_bert: pretrained Bert checkpoint path or key title (e.g. "bert-base-uncased")
        return_probas: set this to `True` if you need the probabilities instead of raw answers
        bert_config_file: path to Bert configuration file, or None, if `pretrained_bert` is a string name
        attention_probs_keep_prob: keep_prob for Bert self-attention layers
        hidden_keep_prob: keep_prob for Bert hidden layers
        optimizer: optimizer name from `torch.optim`
        optimizer_parameters: dictionary with optimizer's parameters,
            e.g. {'lr': 0.1, 'weight_decay': 0.001, 'momentum': 0.9}
        learning_rate_drop_patience: how many validations with no improvements to wait
        learning_rate_drop_div: the divider of the learning rate after `learning_rate_drop_patience`
            unsuccessful validations
        load_before_drop: whether to load the best model before dropping the learning rate or not
        clip_norm: clip gradients by norm
        min_learning_rate: min value of learning rate if learning rate decay is used
    """

    def __init__(self, n_tags: int,
                 pretrained_bert: str,
                 bert_config_file: Optional[str] = None,
                 return_probas: bool = False,
                 attention_probs_keep_prob: Optional[float] = None,
                 hidden_keep_prob: Optional[float] = None,
                 optimizer: str = "AdamW",
                 optimizer_parameters: dict = {"lr": 1e-3, "weight_decay": 1e-6},
                 learning_rate_drop_patience: int = 20,
                 learning_rate_drop_div: float = 2.0,
                 load_before_drop: bool = True,
                 clip_norm: Optional[float] = None,
                 min_learning_rate: float = 1e-07,
                 **kwargs) -> None:
        self.n_classes = n_tags
        self.return_probas = return_probas
        self.attention_probs_keep_prob = attention_probs_keep_prob
        self.hidden_keep_prob = hidden_keep_prob
        self.clip_norm = clip_norm
        self.pretrained_bert = pretrained_bert
        self.bert_config_file = bert_config_file

        super().__init__(optimizer=optimizer,
                         optimizer_parameters=optimizer_parameters,
                         learning_rate_drop_patience=learning_rate_drop_patience,
                         learning_rate_drop_div=learning_rate_drop_div,
                         load_before_drop=load_before_drop,
                         min_learning_rate=min_learning_rate,
                         **kwargs)

    def train_on_batch(self,
                       input_ids: Union[List[List[int]], np.ndarray],
                       input_masks: Union[List[List[int]], np.ndarray],
                       y_masks: Union[List[List[int]], np.ndarray],
                       y: List[List[int]],
                       *args, **kwargs) -> Dict[str, float]:
        """Performs one training step on a batch.

        Args:
            input_ids: batch of indices of subwords
            input_masks: batch of masks which determine what should be attended
            y_masks: batch of masks which mark the first subword unit of each word
            y: batch of token-level label indices
            args: arguments passed to _build_feed_dict and corresponding to additional
                input and output tensors of the derived class.
            kwargs: keyword arguments passed to _build_feed_dict and corresponding to
                additional input and output tensors of the derived class.

        Returns:
            dict with the 'loss' field
        """
        b_input_ids = torch.from_numpy(input_ids).to(self.device)
        b_input_masks = torch.from_numpy(input_masks).to(self.device)
        subtoken_labels = [token_labels_to_subtoken_labels(y_el, y_mask, input_mask)
                           for y_el, y_mask, input_mask in zip(y, y_masks, input_masks)]
        b_labels = torch.from_numpy(np.array(subtoken_labels)).to(torch.int64).to(self.device)
        self.optimizer.zero_grad()

        loss, logits = self.model(input_ids=b_input_ids, token_type_ids=None,
                                  attention_mask=b_input_masks, labels=b_labels)
        loss.backward()
        # Clip the norm of the gradients to `clip_norm`.
        # This helps prevent the "exploding gradients" problem.
        if self.clip_norm:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_norm)

        self.optimizer.step()
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {'loss': loss.item()}

    def __call__(self,
                 input_ids: Union[List[List[int]], np.ndarray],
                 input_masks: Union[List[List[int]], np.ndarray],
                 y_masks: Union[List[List[int]], np.ndarray]) -> Union[List[List[int]], List[np.ndarray]]:
        """Predicts tag indices for a given subword tokens batch.

        Args:
            input_ids: indices of the subwords
            input_masks: mask that determines where to attend and where not to
            y_masks: mask which marks the first subword unit of each word

        Returns:
            Label indices or class probabilities for each token (not subtoken)
        """
        b_input_ids = torch.from_numpy(input_ids).to(self.device)
        b_input_masks = torch.from_numpy(input_masks).to(self.device)

        with torch.no_grad():
            # Forward pass, calculate logit predictions
            logits = self.model(b_input_ids, token_type_ids=None, attention_mask=b_input_masks)

            # Move logits to CPU and aggregate subtoken predictions back to tokens
            logits = token_from_subtoken(logits[0].detach().cpu(), torch.from_numpy(y_masks))

        if self.return_probas:
            pred = torch.nn.functional.softmax(logits, dim=-1)
            pred = pred.detach().cpu().numpy()
        else:
            logits = logits.detach().cpu().numpy()
            pred = np.argmax(logits, axis=-1)
        seq_lengths = np.sum(y_masks, axis=1)
        pred = [p[:l] for l, p in zip(seq_lengths, pred)]

        return pred

    @overrides
    def load(self, fname=None):
        if fname is not None:
            self.load_path = fname

        if self.pretrained_bert and not Path(self.pretrained_bert).is_file():
            self.model = BertForTokenClassification.from_pretrained(
                self.pretrained_bert, num_labels=self.n_classes,
                output_attentions=False, output_hidden_states=False)
        elif self.bert_config_file and Path(self.bert_config_file).is_file():
            self.bert_config = BertConfig.from_json_file(str(expand_path(self.bert_config_file)))

            if self.attention_probs_keep_prob is not None:
                self.bert_config.attention_probs_dropout_prob = 1.0 - self.attention_probs_keep_prob
            if self.hidden_keep_prob is not None:
                self.bert_config.hidden_dropout_prob = 1.0 - self.hidden_keep_prob
            self.model = BertForTokenClassification(config=self.bert_config)
        else:
            raise ConfigError("No pre-trained BERT model is given.")

        self.model.to(self.device)

        self.optimizer = getattr(torch.optim, self.optimizer_name)(
            self.model.parameters(), **self.optimizer_parameters)
        if self.lr_scheduler_name is not None:
            self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)(
                self.optimizer, **self.lr_scheduler_parameters)

        if self.load_path:
            log.info(f"Load path {self.load_path} is given.")

            if isinstance(self.load_path, Path) and not self.load_path.parent.is_dir():
                raise ConfigError("Provided load path is incorrect!")

            weights_path = Path(self.load_path.resolve())
            weights_path = weights_path.with_suffix(".pth.tar")
            if weights_path.exists():
                log.info(f"Load path {weights_path} exists.")
                log.info(f"Initializing `{self.__class__.__name__}` from saved.")

                # Load the model and optimizer state from the saved checkpoint.
                log.info(f"Loading weights from {weights_path}.")
                checkpoint = torch.load(weights_path, map_location=self.device)
                self.model.load_state_dict(checkpoint["model_state_dict"])
                self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                self.epochs_done = checkpoint.get("epochs_done", 0)
            else:
                log.info(f"Init from scratch. Load path {weights_path} does not exist.")
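# Illustrative usage sketch for TorchBertSequenceTagger; it is not part of the
# original file. Assumptions: the TorchModel base class accepts save_path and
# load_path keyword arguments, and the inputs have already been converted to
# padded numpy arrays of subword ids, attention masks, and first-subword masks
# by an upstream preprocessor; the concrete id values below are made up.
import numpy as np

tagger = TorchBertSequenceTagger(
    n_tags=3,                              # e.g. O / B-ENT / I-ENT
    pretrained_bert="bert-base-uncased",
    return_probas=False,
    save_path="model/ner_bert",
    load_path="model/ner_bert")

input_ids = np.array([[101, 7592, 2088, 102, 0, 0]])   # padded subword ids
input_masks = np.array([[1, 1, 1, 1, 0, 0]])           # attention mask
y_masks = np.array([[0, 1, 1, 0, 0, 0]])               # 1 marks the first subword of each word

tag_ids = tagger(input_ids, input_masks, y_masks)      # one list of tag ids per input word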