Example #1
import torch
from torch import nn
from pytorch_lightning import LightningModule
# Assumption: Accuracy comes from torchmetrics; the argument-less Accuracy()
# calls below imply an older (< 0.11) torchmetrics version.
from torchmetrics import Accuracy


class SimpleModel(LightningModule):
    def __init__(self, vocab_size, embedding_dim=32):
        super().__init__()

        self.embeddings_layer = nn.Embedding(vocab_size, embedding_dim)
        self.loss = nn.BCEWithLogitsLoss()
        self.valid_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

    def forward(self, inputs, labels):
        # Subclasses are expected to return a (loss, logits) pair; the
        # training/validation/test steps below rely on that contract.
        raise NotImplementedError("forward not implemented")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return [optimizer]

    def training_step(self, batch, _):
        inputs, labels = batch
        loss, logits = self(inputs, labels)
        return loss

    def validation_step(self, batch, _):
        inputs, labels = batch
        val_loss, logits = self(inputs, labels)
        # Binary 0/1 targets -> sigmoid over the logits; otherwise softmax over classes.
        if torch.max(labels) == 1:
            pred = torch.sigmoid(logits)
        else:
            pred = torch.softmax(logits, 1)
        self.valid_accuracy.update(pred, labels.long())
        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", self.valid_accuracy)

    def validation_epoch_end(self, outs):
        self.log("val_acc_epoch", self.valid_accuracy.compute(), prog_bar=True)

    def test_step(self, batch, _):
        inputs, labels = batch
        test_loss, logits = self(inputs, labels)
        if torch.max(labels) == 1:
            pred = torch.sigmoid(logits)
        else:
            pred = torch.softmax(logits, 1)
        self.test_accuracy.update(pred, labels.long())
        self.log("test_loss", test_loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy)

    def test_epoch_end(self, outs):
        self.log("test_acc_epoch", self.test_accuracy.compute(), prog_bar=True)
Example #2
import torch
import torchvision
import pytorch_lightning as pl
from typing import Any, List
from torch.nn import Linear, CrossEntropyLoss
from torch.nn import functional as F
from torchmetrics import Accuracy
# BaseParticipantModel, CNN_OriginalFedAvg and GlobalConfusionMatrix are
# project-specific helpers from the surrounding (federated learning) codebase.


class Densenet121Lightning(BaseParticipantModel, pl.LightningModule):
    def __init__(self,
                 num_classes,
                 *args,
                 weights=None,
                 pretrain=True,
                 **kwargs):
        model = torchvision.models.densenet121(pretrained=pretrain)
        model.classifier = Linear(in_features=1024,
                                  out_features=num_classes,
                                  bias=True)
        super().__init__(*args, model=model, **kwargs)
        self.model = model
        self.accuracy = Accuracy()
        self.train_accuracy = Accuracy()
        self.criterion = CrossEntropyLoss(weight=weights)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        y = y.long()
        logits = self.model(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.log('train/acc/{}'.format(self.participant_name),
                 self.train_accuracy(preds, y))
        self.log('train/loss/{}'.format(self.participant_name), loss.item())
        return loss

    def test_step(self, test_batch, batch_idx):
        x, y = test_batch
        y = y.long()
        logits = self.model(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy.update(preds, y)
        GlobalConfusionMatrix().update(preds, y)
        return {'loss': loss}

    def test_epoch_end(self, outputs: List[Any]) -> None:
        # `outputs` collects the dicts returned by test_step for every batch.
        loss_list = [o['loss'] for o in outputs]
        loss = torch.stack(loss_list)
        self.log('sample_num', self.accuracy.total.item())
        self.log(f'test/acc/{self.participant_name}', self.accuracy.compute())
        self.log(f'test/loss/{self.participant_name}', loss.mean().item())
class CNNLightning(BaseParticipantModel, pl.LightningModule):

    def __init__(self, only_digits=False, input_channels=1, *args, **kwargs):
        model = CNN_OriginalFedAvg(only_digits=only_digits, input_channels=input_channels)
        super().__init__(*args, model=model, **kwargs)
        self.model = model
        # self.model.apply(init_weights)
        self.accuracy = Accuracy()
        self.train_accuracy = Accuracy()

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        y = y.long()
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)

        self.log(f'train/acc/{self.participant_name}', self.train_accuracy(preds, y).item())
        self.log(f'train/loss/{self.participant_name}', loss.mean().item())
        return loss

    def test_step(self, test_batch, batch_idx):
        x, y = test_batch
        y = y.long()
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy.update(preds, y)
        return {'loss': loss}

    def test_epoch_end(self, outputs: List[Any]) -> None:
        loss_list = [o['loss'] for o in outputs]
        loss = torch.stack(loss_list)
        self.log('sample_num', self.accuracy.total.item())
        self.log(f'test/acc/{self.participant_name}', self.accuracy.compute())
        self.log(f'test/loss/{self.participant_name}', loss.mean().item())
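Example #2's Densenet121Lightning forwards its optional weights argument to CrossEntropyLoss(weight=...), which is typically used to counter class imbalance. The sketch below shows one common inverse-frequency scheme for building such a tensor; the class counts are made up for illustration and nothing here comes from the original project.

import torch

# Hypothetical per-class label counts; use counts from the real dataset instead.
label_counts = torch.tensor([500.0, 100.0, 25.0])
# Inverse-frequency weighting: rarer classes get proportionally larger weights.
class_weights = label_counts.sum() / (len(label_counts) * label_counts)
# Pass as Densenet121Lightning(num_classes=3, weights=class_weights, ...)
# together with whatever arguments BaseParticipantModel expects.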
Example #4
import torch
from torch import nn
from pytorch_lightning import LightningModule
from torchmetrics import Accuracy


class FastTextLSTMModel(LightningModule):
    """
    Run an LSTM over the tokens' FastText embeddings, take the final hidden
    state, and apply dropout and a linear projection.
    """
    def __init__(self, ft_embedding_dim, hidden_dim=64):
        super().__init__()

        self.lstm_layer = nn.LSTM(ft_embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout_layer = nn.Dropout(0.2)
        self.out_layer = nn.Linear(hidden_dim * 2, 1)

        self.loss = nn.BCEWithLogitsLoss()
        self.valid_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

    def forward(self, embeddings, labels):
        """
        Forward pass.
        :param embeddings: (batch_size, max_tokens_in_text, ft_embedding_dim);
            the text is tokenized upstream ("hello", ",", "world", ...), mapped to
            FastText embeddings, and padded or truncated to the max sequence length
        :param labels: (batch_size, 1)
        :return: loss and logits
        """
        batch_size = embeddings.size(0)
        # final_hidden_state: (num_directions=2, batch_size, hidden_dim) for the bidirectional LSTM
        output, (final_hidden_state, final_cell_state) = self.lstm_layer(embeddings)
        final_hidden_state = final_hidden_state.transpose(0, 1)          # (batch_size, 2, hidden_dim)
        final_hidden_state = final_hidden_state.reshape(batch_size, -1)  # (batch_size, 2 * hidden_dim)
        text_hidden = self.dropout_layer(final_hidden_state)
        logits = self.out_layer(text_hidden)
        loss = self.loss(logits, labels.type_as(logits))
        return loss, logits

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return [optimizer]

    def training_step(self, batch, _):
        inputs, labels = batch
        loss, logits = self(inputs, labels)
        return loss

    def validation_step(self, batch, _):
        inputs, labels = batch
        val_loss, logits = self(inputs, labels)
        if torch.max(labels) == 1:
            pred = torch.sigmoid(logits)
        else:
            pred = torch.softmax(logits, 1)
        self.valid_accuracy.update(pred, labels.long())
        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", self.valid_accuracy)

    def validation_epoch_end(self, outs):
        self.log("val_acc_epoch", self.valid_accuracy.compute(), prog_bar=True)

    def test_step(self, batch, _):
        inputs, labels = batch
        test_loss, logits = self(inputs, labels)
        if torch.max(labels) == 1:
            pred = torch.sigmoid(logits)
        else:
            pred = torch.softmax(logits, 1)
        self.test_accuracy.update(pred, labels.long())
        self.log("test_loss", test_loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy)

    def test_epoch_end(self, outs):
        self.log("test_acc_epoch", self.test_accuracy.compute(), prog_bar=True)