import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import classification_report

# Project-local helpers; these module paths are assumptions about the repo layout.
from dataset import DataFolder, transform
from model import MainModel
from engine import train_fn, check_accuracy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train(train_dir, val_dir, checkpoint_dir, batch_size, image_size=512,
          num_epochs=10, checkpoint_name=None, num_workers=1, pin_memory=True,
          log_dir="logs", model_name=None, train_csv=None, val_csv=None):
    # Declare datasets: augmentations are applied to the training split only.
    train_ds = DataFolder(root_dir=train_dir,
                          transform=transform(image_size, is_training=True),
                          csv_path=train_csv)
    val_ds = DataFolder(root_dir=val_dir,
                        transform=transform(image_size, is_training=False),
                        csv_path=val_csv)
    train_loader = DataLoader(train_ds, batch_size=batch_size,
                              num_workers=num_workers, pin_memory=pin_memory,
                              shuffle=True)
    # Sample order does not affect validation metrics, so no shuffling here.
    val_loader = DataLoader(val_ds, batch_size=batch_size,
                            num_workers=num_workers, pin_memory=pin_memory,
                            shuffle=False)

    # Initialize the model (128 output classes) and the training setup.
    model = MainModel(128, model_name)
    model = model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()  # gradient scaling for mixed precision

    # Optionally resume from a checkpoint and report its validation accuracy.
    if checkpoint_name:
        ckp_path = os.path.join(checkpoint_dir, checkpoint_name)
        load_checkpoint(torch.load(ckp_path), model, optimizer)
        check_accuracy(val_loader, model, device)

    # Training loop: one training pass, one validation pass, one checkpoint per epoch.
    for epoch in range(num_epochs):
        train_fn(train_loader, model, optimizer, loss_fn, scaler, device,
                 epoch, log_dir=log_dir)
        check_accuracy(val_loader, model, device)
        checkpoint = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        save_checkpoint(checkpoint,
                        os.path.join(checkpoint_dir, f"checkpoint_{epoch}.pth.tar"))
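
# save_checkpoint and load_checkpoint are project-local helpers whose source is
# not shown here. A minimal sketch of what they are assumed to do, matching the
# checkpoint dicts built above ('state_dict' and 'optimizer' keys):

def save_checkpoint(state, filename="checkpoint.pth.tar"):
    # Persist model (and optimizer) state to disk.
    torch.save(state, filename)


def load_checkpoint(checkpoint, model, optimizer=None):
    # Restore model weights, plus optimizer state when resuming training.
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
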
def test(test_dir, checkpoint_path, batch_size, num_workers=1, pin_memory=True,
         test_csv=None, model_name='efficientnet-b3'):
    # Declare the test dataset; evaluation order is irrelevant, so no shuffling.
    test_ds = DataFolder(root_dir=test_dir,
                         transform=transform(is_training=False),
                         is_test=True, csv_path=test_csv)
    test_loader = DataLoader(test_ds, batch_size=batch_size,
                             num_workers=num_workers, pin_memory=pin_memory,
                             shuffle=False)

    # Initialize the model and restore the trained weights.
    model = MainModel(test_ds.__num_class__(), model_name)
    model = model.to(device)
    load_checkpoint(torch.load(checkpoint_path), model)
    model.eval()

    iterator = tqdm(test_loader)
    num_correct = 0
    num_samples = 0
    preds = []
    groundtruths = []
    print(test_ds.class_names)

    with torch.no_grad():
        for x, y, image_paths in iterator:
            # Move the batch to the target device.
            x = x.to(device=device)
            y = y.to(device=device)

            # Inference: softmax turns the logits into class probabilities
            # (the argmax is unchanged on raw logits); argmax picks the class.
            scores = torch.softmax(model(x), dim=1)
            predictions = torch.argmax(scores, dim=1)

            # Accumulate per-batch results for the final classification report.
            preds += predictions.cpu().numpy().tolist()
            groundtruths += y.cpu().numpy().tolist()

            # Show running accuracy on the progress bar.
            num_correct += (predictions == y).sum()
            num_samples += predictions.shape[0]
            iterator.set_postfix(
                accuracy=f'{float(num_correct) / float(num_samples) * 100:.2f}')

    print(classification_report(groundtruths, preds, zero_division=0,
                                target_names=test_ds.class_names))
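
# A hypothetical entry point showing how the two functions might be wired
# together; the directory and checkpoint paths below are placeholders, not
# taken from the repo.
if __name__ == "__main__":
    train(train_dir="data/train", val_dir="data/val",
          checkpoint_dir="checkpoints", batch_size=16,
          model_name="efficientnet-b3")
    test(test_dir="data/test",
         checkpoint_path="checkpoints/checkpoint_9.pth.tar",
         batch_size=16)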