import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.optim.lr_scheduler import StepLR


def create_model(ema=False):
    """Create a model instance; if ema is True, detach its parameters so they
    are updated manually (e.g. as an exponential moving average) instead of
    by the optimizer."""
    # select the architecture based on the parsed command-line arguments
    # (`args`, `Classifier`, and `WideResNet` are defined elsewhere)
    if args.type == 1:
        model = Classifier(num_classes=10)
    else:
        model = WideResNet(num_classes=10)
    model = model.cuda()

    if ema:
        # detach the parameters so no gradients are tracked for the EMA copy
        for param in model.parameters():
            param.detach_()
    return model
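# A minimal sketch of how the detached EMA copy might be kept in sync with
# the regular model (mean-teacher style). The function name and the decay
# ramp-up are illustrative assumptions, not part of the original code.
def update_ema_variables(model, ema_model, alpha, global_step):
    # ramp up the decay factor during the first steps for a stable start
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        # in-place update: ema = alpha * ema + (1 - alpha) * param
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)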
class FullModel(pl.LightningModule):
    def __init__(self, model_name, vocab, lr, lr_decay, batch_size=64):
        """
        PyTorch Lightning module that creates the overall model.
        Inputs:
            model_name - String denoting which encoder class to use. Either
                'AWE', 'UniLSTM', 'BiLSTM', or 'BiLSTMMax'
            vocab - Vocabulary from the alignment between the SNLI dataset
                and the GloVe vectors
            lr - Learning rate to use for the optimizer
            lr_decay - Learning rate decay factor applied each epoch
            batch_size - Size of the batches. Default is 64
        """
        super().__init__()
        self.save_hyperparameters()

        # create an embedding layer for the vocabulary embeddings
        self.glove_embeddings = nn.Embedding.from_pretrained(vocab.vectors)

        # check which encoder model to use
        if model_name == 'AWE':
            self.encoder = AWEEncoder()
            self.classifier = Classifier()
        elif model_name == 'UniLSTM':
            self.encoder = UniLSTM()
            self.classifier = Classifier(input_dim=4 * 2048)
        elif model_name == 'BiLSTM':
            self.encoder = BiLSTM()
            self.classifier = Classifier(input_dim=4 * 2 * 2048)
        else:
            self.encoder = BiLSTMMax()
            self.classifier = Classifier(input_dim=4 * 2 * 2048)

        # create the loss function
        self.loss_function = nn.CrossEntropyLoss()

        # create an attribute to save the last validation accuracy
        self.last_val_acc = None

    def forward(self, sentences):
        """
        The forward function calculates the loss and accuracy for a given
        batch of sentences.
        Inputs:
            sentences - Batch of sentences with (premise, hypothesis, label) pairs
        Outputs:
            loss - Cross-entropy loss of the predictions
            accuracy - Accuracy of the predictions
        """
        # get the sentence lengths of the batch (token id 1 is the padding token)
        lengths_premises = torch.tensor(
            [x[x != 1].shape[0] for x in sentences.premise], device=self.device)
        lengths_hypothesis = torch.tensor(
            [x[x != 1].shape[0] for x in sentences.hypothesis], device=self.device)

        # pass the premises and hypotheses through the embeddings
        premises = self.glove_embeddings(sentences.premise)
        hypothesis = self.glove_embeddings(sentences.hypothesis)

        # forward the premises and hypotheses through the encoder
        premises = self.encoder(premises, lengths_premises)
        hypothesis = self.encoder(hypothesis, lengths_hypothesis)

        # calculate the absolute difference and element-wise product
        difference = torch.abs(premises - hypothesis)
        multiplication = premises * hypothesis

        # create the sentence representations [u, v, |u - v|, u * v]
        sentence_representations = torch.cat(
            [premises, hypothesis, difference, multiplication], dim=1)

        # pass through the classifier
        predictions = self.classifier(sentence_representations)

        # calculate the loss and accuracy
        loss = self.loss_function(predictions, sentences.label)
        predicted_labels = torch.argmax(predictions, dim=1)
        accuracy = torch.true_divide(
            torch.sum(predicted_labels == sentences.label),
            sentences.label.shape[0])

        # return the loss and accuracy
        return loss, accuracy

    def configure_optimizers(self):
        """Configure the optimizer and the learning rate scheduler."""
        # create the optimizer for the encoder and classifier parameters
        optimizer = torch.optim.SGD(
            [{'params': self.encoder.parameters()},
             {'params': self.classifier.parameters()}],
            lr=self.hparams.lr)

        # freeze the pretrained embeddings
        self.glove_embeddings.weight.requires_grad = False

        # decay the learning rate by a fixed factor every epoch
        lr_scheduler = {
            'scheduler': StepLR(optimizer=optimizer, step_size=1,
                                gamma=self.hparams.lr_decay),
            'name': 'learning_rate'
        }

        # return the optimizer and the scheduler
        return [optimizer], [lr_scheduler]

    def training_step(self, batch, batch_idx):
        """Forward a training batch and log the metrics."""
        train_loss, train_acc = self.forward(batch)

        # log the training loss and accuracy
        self.log("train_loss", train_loss, on_step=False, on_epoch=True)
        self.log("train_acc", train_acc, on_step=False, on_epoch=True)

        # return the training loss
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Forward a validation batch and log the metrics."""
        val_loss, val_acc = self.forward(batch)

        # log the validation loss and accuracy
        self.log("val_loss", val_loss)
        self.log("val_acc", val_acc)

        # save the validation accuracy
        self.last_val_acc = val_acc

    def test_step(self, batch, batch_idx):
        """Forward a test batch and log the metrics."""
        test_loss, test_acc = self.forward(batch)

        # log the test loss and accuracy
        self.log("test_loss", test_loss)
        self.log("test_acc", test_acc)