class SimpleModel(LightningModule):
    """Base Lightning module with an embedding layer and a BCE-with-logits loss.

    ``forward`` is left unimplemented: subclasses must implement
    ``forward(inputs, labels)`` returning a ``(loss, logits)`` tuple, which is
    the contract the training/validation/test steps below rely on.
    """

    def __init__(self, vocab_size, embedding_dim=32):
        """
        :param vocab_size: number of rows in the embedding table
        :param embedding_dim: size of each embedding vector
        """
        super().__init__()
        self.embeddings_layer = nn.Embedding(vocab_size, embedding_dim)
        self.loss = nn.BCEWithLogitsLoss()
        self.valid_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

    def forward(self, inputs, labels):
        raise NotImplementedError("forward not implemented")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return [optimizer]

    @staticmethod
    def _to_probabilities(logits):
        """Convert raw logits to probabilities for the accuracy metric.

        1-D logits or a single output column mean a binary head, so a sigmoid
        is applied; otherwise a softmax over dimension 1 (the class axis).
        The choice is made from the logits' shape, not from the batch labels:
        a binary batch whose labels happen to be all 0 would otherwise get a
        softmax over a single column, which is always 1.0.
        """
        if logits.dim() == 1 or logits.size(1) == 1:
            return torch.sigmoid(logits)
        return torch.softmax(logits, 1)

    def training_step(self, batch, _):
        inputs, labels = batch
        loss, _logits = self(inputs, labels)
        return loss

    def validation_step(self, batch, _):
        inputs, labels = batch
        val_loss, logits = self(inputs, labels)
        pred = self._to_probabilities(logits)
        self.valid_accuracy.update(pred, labels.long())
        self.log("val_loss", val_loss, prog_bar=True)
        # Logging the metric object lets Lightning compute/reset it per epoch.
        self.log("val_acc", self.valid_accuracy)

    def validation_epoch_end(self, outs):
        self.log("val_acc_epoch", self.valid_accuracy.compute(), prog_bar=True)

    def test_step(self, batch, _):
        inputs, labels = batch
        test_loss, logits = self(inputs, labels)
        pred = self._to_probabilities(logits)
        self.test_accuracy.update(pred, labels.long())
        self.log("test_loss", test_loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy)

    def test_epoch_end(self, outs):
        self.log("test_acc_epoch", self.test_accuracy.compute(), prog_bar=True)
class Densenet121Lightning(BaseParticipantModel, pl.LightningModule):
    """Lightning wrapper around torchvision's DenseNet-121 for one participant.

    The stock classifier is replaced by a ``num_classes``-way linear layer and
    trained with (optionally class-weighted) cross-entropy. Metric keys are
    suffixed with ``self.participant_name``, presumably provided by
    ``BaseParticipantModel`` — confirm against that class.
    """

    def __init__(self, num_classes, *args, weights=None, pretrain=True, **kwargs):
        """
        :param num_classes: number of output classes of the new classifier head
        :param weights: optional per-class weights passed to ``CrossEntropyLoss``
        :param pretrain: forwarded to torchvision as ``pretrained=``
        :param args, kwargs: forwarded to ``BaseParticipantModel`` together with ``model``
        """
        model = torchvision.models.densenet121(pretrained=pretrain)
        # Reuse the original head's input width (1024 for DenseNet-121).
        model.classifier = Linear(
            in_features=model.classifier.in_features,
            out_features=num_classes,
            bias=True,
        )
        super().__init__(*args, model=model, **kwargs)
        self.model = model
        self.accuracy = Accuracy()
        self.train_accuracy = Accuracy()
        self.criterion = CrossEntropyLoss(weight=weights)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        y = y.long()
        logits = self.model(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.log('train/acc/{}'.format(self.participant_name), self.train_accuracy(preds, y))
        self.log('train/loss/{}'.format(self.participant_name), loss.item())
        return loss

    def on_test_epoch_start(self) -> None:
        """Reset the test accuracy so repeated test runs do not accumulate.

        ``test_epoch_end`` logs ``self.accuracy.compute()`` as a plain tensor
        rather than the metric object, so Lightning does not reset the metric.
        Without this reset, a second test run on the same module would mix
        samples (and ``sample_num``) from both runs.
        """
        super().on_test_epoch_start()
        self.accuracy.reset()

    def test_step(self, test_batch, batch_idx):
        x, y = test_batch
        y = y.long()
        logits = self.model(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy.update(preds, y)
        # NOTE(review): GlobalConfusionMatrix() is constructed per batch, so it
        # presumably returns a shared/singleton accumulator — confirm.
        GlobalConfusionMatrix().update(preds, y)
        return {'loss': loss}

    def test_epoch_end(self, outputs: List[Any]) -> None:
        # Unweighted mean of the per-batch mean losses.
        loss = torch.stack([o['loss'] for o in outputs])
        self.log('sample_num', self.accuracy.total.item())
        self.log(f'test/acc/{self.participant_name}', self.accuracy.compute())
        self.log(f'test/loss/{self.participant_name}', loss.mean().item())
class CNNLightning(BaseParticipantModel, pl.LightningModule):
    """Lightning wrapper around ``CNN_OriginalFedAvg`` for one participant.

    Trains with cross-entropy and logs accuracy/loss under keys suffixed with
    ``self.participant_name``, presumably provided by ``BaseParticipantModel``
    — confirm against that class.
    """

    def __init__(self, only_digits=False, input_channels=1, *args, **kwargs):
        """
        :param only_digits: forwarded to ``CNN_OriginalFedAvg``
        :param input_channels: forwarded to ``CNN_OriginalFedAvg``
        :param args, kwargs: forwarded to ``BaseParticipantModel`` together with ``model``
        """
        model = CNN_OriginalFedAvg(only_digits=only_digits, input_channels=input_channels)
        super().__init__(*args, model=model, **kwargs)
        self.model = model
        self.accuracy = Accuracy()
        self.train_accuracy = Accuracy()

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        y = y.long()
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.log(f'train/acc/{self.participant_name}', self.train_accuracy(preds, y).item())
        # cross_entropy already reduces to a scalar mean, so no extra .mean().
        self.log(f'train/loss/{self.participant_name}', loss.item())
        return loss

    def on_test_epoch_start(self) -> None:
        """Reset the test accuracy so repeated test runs do not accumulate.

        ``test_epoch_end`` logs ``self.accuracy.compute()`` as a plain tensor
        rather than the metric object, so Lightning does not reset the metric.
        Without this reset, a second test run on the same module would mix
        samples (and ``sample_num``) from both runs.
        """
        super().on_test_epoch_start()
        self.accuracy.reset()

    def test_step(self, test_batch, batch_idx):
        x, y = test_batch
        y = y.long()
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy.update(preds, y)
        return {'loss': loss}

    def test_epoch_end(self, outputs: List[Any]) -> None:
        # Unweighted mean of the per-batch mean losses.
        loss = torch.stack([o['loss'] for o in outputs])
        self.log('sample_num', self.accuracy.total.item())
        self.log(f'test/acc/{self.participant_name}', self.accuracy.compute())
        self.log(f'test/loss/{self.participant_name}', loss.mean().item())
class FastTextLSTMModel(LightningModule):
    """
    Run LSTM over tokens FastText embeddings and take final hidden state,
    add linear projection and dropout

    A single-layer bidirectional LSTM reads the embedding sequence; the final
    forward and backward hidden states are concatenated, passed through
    dropout, and projected to one logit trained with ``BCEWithLogitsLoss``.
    """

    def __init__(self, ft_embedding_dim, hidden_dim=64):
        """
        :param ft_embedding_dim: size of each input FastText vector
        :param hidden_dim: LSTM hidden size per direction
        """
        super().__init__()
        self.lstm_layer = nn.LSTM(ft_embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout_layer = nn.Dropout(0.2)
        # Both directions are concatenated, hence hidden_dim * 2 input features.
        self.out_layer = nn.Linear(hidden_dim * 2, 1)
        self.loss = nn.BCEWithLogitsLoss()
        self.valid_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

    def forward(self, embeddings, labels):
        """
        Forward pass

        :param embeddings: (batch_size, max_tokens_in_text, ft_embedding_dim) FastText
            vectors, one per token, padded or cut to a common sequence length
        :param labels: (batch_size, 1) binary targets; must match the logits' shape
            as required by BCEWithLogitsLoss
        :return: (loss, logits) with logits of shape (batch_size, 1)
        """
        batch_size = embeddings.size(0)
        # For this 1-layer bidirectional LSTM, final_hidden_state is (2, batch, hidden_dim).
        _output, (final_hidden_state, _final_cell_state) = self.lstm_layer(embeddings)
        # -> (batch, 2, hidden_dim) -> (batch, 2 * hidden_dim); reshape handles non-contiguity.
        final_hidden_state = final_hidden_state.transpose(0, 1).reshape(batch_size, -1)
        text_hidden = self.dropout_layer(final_hidden_state)
        # Call the module rather than .forward() so registered hooks run.
        logits = self.out_layer(text_hidden)
        loss = self.loss(logits, labels.type_as(logits))
        return loss, logits

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return [optimizer]

    @staticmethod
    def _to_probabilities(logits):
        """Convert raw logits to probabilities for the accuracy metric.

        1-D logits or a single output column (as produced by ``out_layer``)
        get a sigmoid; otherwise a softmax over dimension 1. Deciding from the
        batch labels instead would apply a softmax over one column whenever a
        batch is all-negative, which always yields 1.0.
        """
        if logits.dim() == 1 or logits.size(1) == 1:
            return torch.sigmoid(logits)
        return torch.softmax(logits, 1)

    def training_step(self, batch, _):
        inputs, labels = batch
        loss, _logits = self(inputs, labels)
        return loss

    def validation_step(self, batch, _):
        inputs, labels = batch
        val_loss, logits = self(inputs, labels)
        pred = self._to_probabilities(logits)
        self.valid_accuracy.update(pred, labels.long())
        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", self.valid_accuracy)

    def validation_epoch_end(self, outs):
        self.log("val_acc_epoch", self.valid_accuracy.compute(), prog_bar=True)

    def test_step(self, batch, _):
        inputs, labels = batch
        test_loss, logits = self(inputs, labels)
        pred = self._to_probabilities(logits)
        self.test_accuracy.update(pred, labels.long())
        self.log("test_loss", test_loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy)

    def test_epoch_end(self, outs):
        self.log("test_acc_epoch", self.test_accuracy.compute(), prog_bar=True)