Example #1
0
class PneumoniaTrainer:
    """Train, evaluate and save a single-logit ``SimpleNet`` classifier.

    The trainer builds train/validation/test loaders over ``PneumoniaDataset``,
    optimises the network with SGD against ``BCEWithLogitsLoss`` and reports
    loss, accuracy and confusion-matrix metrics to a TensorBoard
    ``SummaryWriter``.
    """

    def __init__(
        self, model_name: str, epochs: int = 100, config: Optional[dict] = None
    ):
        """Build the data loaders, model, loss, optimizer and LR scheduler.

        Args:
            model_name: Base name of the saved weights; ``".pth"`` is appended.
            epochs: Number of passes over the training loader.
            config: Optional overrides merged on top of the default
                hyper-parameters below (unknown keys are simply stored).
        """
        self.output_model = model_name + ".pth"
        self.train_set = PneumoniaDataset(SetType.train)
        # Reuse the dataset built above instead of constructing a second copy.
        self.train_loader = DataLoader(
            self.train_set, batch_size=16, shuffle=True, num_workers=8
        )
        self.val_loader = DataLoader(
            PneumoniaDataset(SetType.val, shuffle=False),
            batch_size=16,
            shuffle=False,
            num_workers=8,
        )
        self.test_loader = DataLoader(
            PneumoniaDataset(SetType.test, shuffle=False),
            batch_size=16,
            shuffle=False,
            num_workers=8,
        )
        self.config = {
            "pos_weight_bias": 0.5,
            "starting_lr": 1e-2,
            "momentum": 0.9,
            "decay": 5e-4,
            "lr_adjustment_factor": 0.3,
            "scheduler_patience": 15,
            "print_cadence": 100,
            "comment": "Added large dense layer.",
            "pos_weight": 1341 / 3875,  # Number of negatives / positives.
        }
        # Caller-supplied values take precedence over the defaults; previously
        # the ``config`` argument was accepted but never applied.
        if config:
            self.config.update(config)

        self.epochs = epochs
        # Prefer the first GPU, but stay usable on CPU-only machines.
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu"
        )
        self.writer = SummaryWriter(comment=self.config["comment"])
        self.net = SimpleNet(1).to(self.device)
        # Keep the class weight on the same device as the logits.
        self.criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor(self.config["pos_weight"]).to(self.device)
        )
        self.optimizer = optim.SGD(
            self.net.parameters(),
            lr=self.config["starting_lr"],  # type: ignore
            momentum=self.config["momentum"],  # type: ignore
            weight_decay=self.config["decay"],  # type: ignore
        )
        # mode="max" because the scheduler is stepped with an accuracy.
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            factor=self.config["lr_adjustment_factor"],  # type: ignore
            mode="max",
            verbose=True,
            patience=self.config["scheduler_patience"],  # type: ignore
        )

        print("Trainer Initialized.")
        # len() of a DataLoader is its number of batches.
        for loader in [self.train_loader, self.test_loader, self.val_loader]:
            print(f"Size of set: {len(loader)}")

    def train(self):
        """Run the full training loop, then evaluate on the test set and save.

        Every ``print_cadence`` batches the mean loss of that window is printed
        and written to TensorBoard. After each epoch, training and validation
        metrics are logged and the LR scheduler is stepped with the training
        accuracy.
        """
        cadence = self.config["print_cadence"]
        training_pass = 0
        for epoch in range(self.epochs):
            # Metric calculation at the end of the previous epoch leaves the
            # network in eval mode, so switch back before training.
            self.net.train()
            running_loss = 0.0
            for i, (inputs, labels, _metadata) in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                outputs = self.net(inputs.float().to(self.device))
                # Labels get a trailing dimension so they line up with the
                # network output (assumed to be (batch, 1) - TODO confirm).
                loss = self.criterion(
                    outputs, labels.unsqueeze(1).float().to(self.device)
                )
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()
                # (i + 1) so every window holds exactly ``cadence`` batches;
                # the old ``i > 0 and i % cadence`` test put 101 batches in
                # the first window while still dividing by ``cadence``.
                if (i + 1) % cadence == 0:
                    mean_loss = running_loss / cadence
                    print(
                        f'Epoch: {epoch}\tBatch: {i}\tLoss: {mean_loss}'
                    )
                    self.writer.add_scalar(
                        "Train/RunningLoss",
                        mean_loss,
                        training_pass,
                    )
                    running_loss = 0.0
                training_pass += 1
            train_accuracy = self.log_training_metrics(epoch)
            self.log_validation_metrics(epoch)
            self.scheduler.step(train_accuracy)
        accuracy, metrics = self.calculate_accuracy(self.test_loader)
        self.writer.add_text("Test/Accuracy", f"{accuracy}")
        for key, val in metrics.items():
            self.writer.add_text(f"Test/{key}", f"{val}")
        self.save_model()

    def log_training_metrics(self, epoch: int):
        """Log accuracy and metrics on the training loader; return accuracy."""
        accuracy, metrics = self.calculate_accuracy(self.train_loader)
        self.writer.add_scalar("Train/Accuracy", accuracy, epoch)
        for key, val in metrics.items():
            self.writer.add_scalar(f"Train/{key}", val, epoch)
        return accuracy

    def log_validation_metrics(self, epoch: int):
        """Log accuracy and metrics on the validation loader; return accuracy."""
        accuracy, metrics = self.calculate_accuracy(self.val_loader)
        self.writer.add_scalar("Validation/Accuracy", accuracy, epoch)
        for key, val in metrics.items():
            self.writer.add_scalar(f"Validation/{key}", val, epoch)
        return accuracy

    @staticmethod
    def _confusion_metrics(tn, fp, fn, tp):
        """Derive rate metrics from binary confusion-matrix counts.

        Args:
            tn, fp, fn, tp: True-negative, false-positive, false-negative and
                true-positive counts.

        Returns:
            Dict with ``Recall`` (tp / (tp + fn)), ``Precision``
            (tp / (tp + fp)), ``FalseNegativeRate`` (fn / (fn + tp)) and
            ``FalsePositiveRate`` (fp / (fp + tn)). A ratio whose denominator
            is zero is reported as ``0.0``.
        """

        def ratio(numerator, denominator):
            return float(numerator) / float(denominator) if denominator else 0.0

        return {
            "Recall": ratio(tp, tp + fn),
            "Precision": ratio(tp, tp + fp),
            "FalseNegativeRate": ratio(fn, fn + tp),
            "FalsePositiveRate": ratio(fp, fp + tn),
        }

    def calculate_accuracy(self, loader: DataLoader):
        """Evaluate the network on ``loader``.

        Logits are passed through a sigmoid and rounded to 0/1 predictions,
        which are compared against the labels.

        Args:
            loader: Iterable yielding ``(inputs, labels, metadata)`` batches.

        Returns:
            Tuple ``(accuracy, metrics)`` where ``metrics`` is the dict
            produced by ``_confusion_metrics``. An empty loader yields an
            accuracy of ``0.0`` and all-zero metrics.
        """
        truth_list: list = []
        pred_list: list = []
        correct = 0.0
        total = 0.0
        self.net.eval()
        with torch.no_grad():
            for inputs, labels, _metadata in loader:
                outputs = self.net(inputs.float().to(self.device))
                # torch.round keeps the result a tensor; the previous
                # np.round on a tensor relied on NumPy/PyTorch interop.
                preds = torch.round(torch.sigmoid(outputs)).cpu().squeeze(1)
                pred_list.extend(preds.int().tolist())
                truth_list.extend(labels.int().tolist())
                total += labels.size(0)
                correct += preds.eq(labels.float()).sum().item()
        print(f"Correct:\t{correct}, Incorrect:\t{total-correct}")

        if total == 0:
            return 0.0, self._confusion_metrics(0, 0, 0, 0)

        # Fixing labels=[0, 1] guarantees a 2x2 matrix even when only one
        # class is present, so the four-way unpack cannot fail.
        tn, fp, fn, tp = (
            int(count)
            for count in confusion_matrix(
                truth_list, pred_list, labels=[0, 1]
            ).ravel()
        )
        metrics = self._confusion_metrics(tn, fp, fn, tp)

        return correct / total, metrics

    def save_model(self):
        """Write the network's ``state_dict`` to ``self.output_model``."""
        print("saving...")
        torch.save(self.net.state_dict(), self.output_model)
def main():
    """Train a 10-class ``SimpleNet`` and log progress to Elasticsearch.

    Records are read from the staged-data index, shuffled with a fixed seed
    and split by their ``set_type`` field. The mean cross-entropy loss is
    indexed every ``print_on`` batches, test accuracy is indexed every
    ``test_every`` epochs, and the final weights are saved to
    ``<logging index>.pth``.
    """
    es_staged_data_index = "cifar-metadata-1"
    es_logging_index = "custom-net-cifar-12"
    output_model = es_logging_index + ".pth"
    print_on = 100  # Batches per running-loss window.
    test_every = 10  # Epochs between test-set evaluations.

    es = Elasticsearch("localhost:9200")
    data = [doc["_source"] for doc in scan(es, index=es_staged_data_index)]

    # Fixed seed so the shuffled order is reproducible between runs.
    np.random.seed(42)
    np.random.shuffle(data)
    training_data = [x for x in data if "train" in x["set_type"]]
    testing_data = [x for x in data if "test" in x["set_type"]]
    print(f"Size of training set: {len(training_data)}")
    print(f"Size of testing set: {len(testing_data)}")

    train_dataset_loader = _get_dataset_loader(training_data,
                                               transform=transform_train,
                                               shuffle=True)
    test_dataset_loader = _get_dataset_loader(testing_data)

    # Use the GPU when present, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = SimpleNet(10).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=1e-3,
                          momentum=0.9,
                          weight_decay=5e-4)
    # Train
    print("training...")
    for epoch in range(300):
        # The periodic test pass below switches to eval mode; switch back.
        net.train()
        running_loss = 0.0

        for i, (inputs, labels) in enumerate(train_dataset_loader):
            optimizer.zero_grad()
            outputs = net(inputs.float().to(device))
            loss = criterion(outputs, labels.long().to(device))
            loss.backward()
            optimizer.step()

            # print stats
            running_loss += loss.item()
            if (i + 1) % print_on == 0:
                # One mean for both the indexed record and the console line;
                # the console previously divided by print_on + 1.
                mean_loss = running_loss / print_on
                record = {
                    "timestamp": datetime.utcnow().isoformat(),
                    "cross-entropy-loss": mean_loss,
                    "model-name": "train-simplenet-8"
                }
                es.index(index=es_logging_index, body=record)
                print('[%d, %5d] loss %.3f' % (epoch + 1, i + 1, mean_loss))
                running_loss = 0.0

        # Test
        # Parenthesised on purpose: ``epoch + 1 % 10`` parses as
        # ``epoch + (1 % 10)``, which is always truthy.
        if (epoch + 1) % test_every == 0:
            print("testing...")
            net.eval()
            with torch.no_grad():
                correct = 0.0
                total = 0.0
                for inputs, labels in test_dataset_loader:
                    outputs = net(inputs.float().to(device))
                    # Index of the largest logit is the predicted class.
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels.to(device)).sum().item()

                # An empty test set reports 0.0 rather than dividing by zero.
                test_accuracy = correct / total if total else 0.0
            print(f"Test Accuracy: {test_accuracy}")
            print(f"Correct: {correct}, Incorrect: {total-correct}")
            record = {
                "accuracy": test_accuracy,
                "correct": correct,
                "incorrect": total - correct,
                "timestamp": datetime.utcnow().isoformat()
            }
            es.index(index=es_logging_index, body=record)

    # Save
    print("saving...")
    torch.save(net.state_dict(), output_model)