class PneumoniaTrainer:
    """Train, evaluate and checkpoint a ``SimpleNet`` binary classifier on the pneumonia data.

    Training/validation/test metrics are written through a ``SummaryWriter`` and
    the final weights are saved to ``<model_name>.pth``.
    """

    def __init__(
        self, model_name: str, epochs: int = 100, config: Optional[dict] = None
    ):
        """Build data loaders, model, loss, optimizer and LR scheduler.

        Args:
            model_name: Base name of the checkpoint file (``.pth`` is appended).
            epochs: Number of passes over the training set in :meth:`train`.
            config: Optional overrides; keys given here replace the matching
                default hyper-parameters defined below.
        """
        self.output_model = model_name + ".pth"
        # Build the training dataset once and reuse it for the loader.
        self.train_set = PneumoniaDataset(SetType.train)
        self.train_loader = DataLoader(
            self.train_set, batch_size=16, shuffle=True, num_workers=8
        )
        self.val_loader = DataLoader(
            PneumoniaDataset(SetType.val, shuffle=False),
            batch_size=16,
            shuffle=False,
            num_workers=8,
        )
        self.test_loader = DataLoader(
            PneumoniaDataset(SetType.test, shuffle=False),
            batch_size=16,
            shuffle=False,
            num_workers=8,
        )
        self.config = {
            "pos_weight_bias": 0.5,
            "starting_lr": 1e-2,
            "momentum": 0.9,
            "decay": 5e-4,
            "lr_adjustment_factor": 0.3,
            "scheduler_patience": 15,
            "print_cadence": 100,
            "comment": "Added large dense layer.",
            "pos_weight": 1341 / 3875,  # Number of negatives / positives.
        }
        if config:
            self.config.update(config)
        self.epochs = epochs
        # Fall back to CPU instead of crashing when no GPU is present.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.writer = SummaryWriter(comment=self.config["comment"])
        self.net = SimpleNet(1).to(self.device)
        # Move the loss too, so its pos_weight buffer lives on the same device.
        self.criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor(self.config["pos_weight"])
        ).to(self.device)
        self.optimizer = optim.SGD(
            self.net.parameters(),
            lr=self.config["starting_lr"],  # type: ignore
            momentum=self.config["momentum"],  # type: ignore
            weight_decay=self.config["decay"],  # type: ignore
        )
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            factor=self.config["lr_adjustment_factor"],  # type: ignore
            mode="max",
            verbose=True,
            patience=self.config["scheduler_patience"],  # type: ignore
        )
        print("Trainer Initialized.")
        for loader in [self.train_loader, self.test_loader, self.val_loader]:
            # len(loader) is the number of batches; report the sample count.
            print(f"Size of set: {len(loader.dataset)}")

    def train(self):
        """Train for ``self.epochs`` epochs, then evaluate on the test set and save.

        Every ``print_cadence`` batches, the mean loss over exactly that many
        batches is printed and logged as ``Train/RunningLoss``. After each epoch,
        training and validation metrics are logged and the LR scheduler is stepped.
        """
        cadence = self.config["print_cadence"]
        training_pass = 0  # global step for the running-loss curve
        for epoch in range(self.epochs):
            # calculate_accuracy() leaves the net in eval mode after every epoch,
            # so switch back to training mode once per epoch.
            self.net.train()
            running_loss = 0.0
            for i, (inputs, labels, _metadata) in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                outputs = self.net(inputs.float().to(self.device))
                loss = self.criterion(
                    outputs, labels.unsqueeze(1).float().to(self.device)
                )
                loss.backward()
                self.optimizer.step()

                running_loss += loss.item()
                # (i + 1) so each reporting window holds exactly `cadence` batches.
                if (i + 1) % cadence == 0:
                    mean_loss = running_loss / cadence
                    print(f"Epoch: {epoch}\tBatch: {i + 1}\tLoss: {mean_loss}")
                    self.writer.add_scalar(
                        "Train/RunningLoss",
                        mean_loss,
                        training_pass,
                    )
                    running_loss = 0.0
                    training_pass += 1

            train_accuracy = self.log_training_metrics(epoch)
            self.log_validation_metrics(epoch)
            # NOTE(review): the plateau scheduler steps on *training* accuracy;
            # the validation accuracy returned above is unused — confirm intent.
            self.scheduler.step(train_accuracy)

        accuracy, metrics = self.calculate_accuracy(self.test_loader)
        self.writer.add_text("Test/Accuracy", f"{accuracy}")
        for key, val in metrics.items():
            self.writer.add_text(f"Test/{key}", f"{val}")
        self.save_model()
        self.writer.close()

    def log_training_metrics(self, epoch: int):
        """Evaluate on the training loader, log scalars under ``Train/`` and return accuracy."""
        accuracy, metrics = self.calculate_accuracy(self.train_loader)
        self.writer.add_scalar("Train/Accuracy", accuracy, epoch)
        for key, val in metrics.items():
            self.writer.add_scalar(f"Train/{key}", val, epoch)
        return accuracy

    def log_validation_metrics(self, epoch: int):
        """Evaluate on the validation loader, log scalars under ``Validation/`` and return accuracy."""
        accuracy, metrics = self.calculate_accuracy(self.val_loader)
        self.writer.add_scalar("Validation/Accuracy", accuracy, epoch)
        for key, val in metrics.items():
            self.writer.add_scalar(f"Validation/{key}", val, epoch)
        return accuracy

    @staticmethod
    def _safe_div(numerator: float, denominator: float) -> float:
        """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
        return numerator / denominator if denominator else 0.0

    def calculate_accuracy(self, loader: DataLoader):
        """Evaluate the network on ``loader``.

        Each batch is unpacked as ``(inputs, labels, metadata)``. The network's
        single logit per sample goes through a sigmoid and is rounded to a 0/1
        prediction (a probability of exactly 0.5 rounds to 0).

        Returns:
            ``(accuracy, metrics)``, where ``metrics`` contains Recall, Precision,
            FalseNegativeRate (fn / (fn + tp)) and FalsePositiveRate
            (fp / (fp + tn)). A ratio with a zero denominator is reported as 0.0.

        Raises:
            ValueError: if ``loader`` yields no samples.
        """
        correct = 0.0
        total = 0.0
        tp = tn = fp = fn = 0
        with torch.no_grad():
            self.net.eval()
            for inputs, labels, _metadata in loader:
                logits = self.net(inputs.float().to(self.device))
                preds = torch.round(torch.sigmoid(logits)).squeeze(1).cpu()
                truth = labels.float().cpu()
                total += labels.size(0)
                correct += preds.eq(truth).sum().item()
                # Accumulate the confusion matrix directly; this works even when
                # only one class appears in the data.
                pred_pos = preds == 1
                true_pos = truth == 1
                tp += int((pred_pos & true_pos).sum())
                tn += int((~pred_pos & ~true_pos).sum())
                fp += int((pred_pos & ~true_pos).sum())
                fn += int((~pred_pos & true_pos).sum())
        if total == 0:
            raise ValueError("calculate_accuracy received an empty loader")
        print(f"Correct:\t{correct}, Incorrect:\t{total-correct}")
        metrics = {
            "Recall": self._safe_div(tp, tp + fn),
            "Precision": self._safe_div(tp, tp + fp),
            "FalseNegativeRate": self._safe_div(fn, fn + tp),
            "FalsePositiveRate": self._safe_div(fp, fp + tn),
        }
        return correct / total, metrics

    def save_model(self):
        """Write the network's ``state_dict`` to ``self.output_model``."""
        print("saving...")
        torch.save(self.net.state_dict(), self.output_model)
def main():
    """Train and evaluate a 10-class ``SimpleNet`` on data staged in Elasticsearch.

    Documents are read from the ``cifar-metadata-1`` index, shuffled with a fixed
    seed, and split on their ``set_type`` field. Mean training loss (every 100
    batches) and test accuracy (every 10 epochs) are indexed into
    ``custom-net-cifar-12``. The final weights are saved to
    ``custom-net-cifar-12.pth``.
    """
    es_staged_data_index = "cifar-metadata-1"
    es_logging_index = "custom-net-cifar-12"
    output_model = es_logging_index + ".pth"
    print_on = 100  # batches between loss reports
    test_every = 10  # epochs between test-set evaluations

    es = Elasticsearch("localhost:9200")
    data = [doc["_source"] for doc in scan(es, index=es_staged_data_index)]
    np.random.seed(42)
    np.random.shuffle(data)

    training_data = [x for x in data if "train" in x["set_type"]]
    testing_data = [x for x in data if "test" in x["set_type"]]

    print(f"Size of training set: {len(training_data)}")
    print(f"Size of testing set: {len(testing_data)}")

    # NOTE(review): an earlier comment here read "didnt use this time around";
    # it is unclear what it referred to (possibly transform_train) — confirm.
    train_dataset_loader = _get_dataset_loader(
        training_data, transform=transform_train, shuffle=True
    )
    test_dataset_loader = _get_dataset_loader(testing_data)

    net = SimpleNet(10).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)

    # Train
    print("training...")
    for epoch in range(300):
        net.train()  # evaluation below switches to eval mode; switch back each epoch
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_dataset_loader):
            optimizer.zero_grad()
            outputs = net(inputs.float().cuda())
            loss = criterion(outputs, labels.long().cuda())
            loss.backward()
            optimizer.step()

            # print stats
            running_loss += loss.item()
            if (i + 1) % print_on == 0:
                mean_loss = running_loss / print_on
                record = {
                    "timestamp": datetime.utcnow().isoformat(),
                    "cross-entropy-loss": mean_loss,
                    "model-name": "train-simplenet-8",
                }
                es.index(index=es_logging_index, body=record)
                # Print the same mean that was logged (was divided by print_on + 1).
                print('[%d, %5d] loss %.3f' % (epoch + 1, i + 1, mean_loss))
                running_loss = 0.0

        # Test every `test_every` epochs. The old `epoch + 1 % 10` parsed as
        # `epoch + (1 % 10)`, which is always truthy.
        if (epoch + 1) % test_every == 0:
            print("testing...")
            net.eval()
            correct = 0.0
            total = 0.0
            with torch.no_grad():
                for inputs, labels in test_dataset_loader:
                    outputs = net(inputs.float().cuda())
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels.cuda()).sum().item()

            test_accuracy = correct / total if total else 0.0
            print(f"Test Accuracy: {test_accuracy}")
            print(f"Correct: {correct}, Incorrect: {total-correct}")
            record = {
                "accuracy": test_accuracy,
                "correct": correct,
                "incorrect": total - correct,
                "timestamp": datetime.utcnow().isoformat(),
            }
            es.index(index=es_logging_index, body=record)

    # Save
    print("saving...")
    torch.save(net.state_dict(), output_model)