def train_model(cust_model, dataloaders, criterion, optimizer, num_epochs = 10, scheduler = None):
    """Train ``cust_model`` and restore the weights with the best validation Jaccard.

    Each epoch runs a "train" phase (forward, backward, optimizer step) and a
    "valid" phase (forward only). For each phase it prints the per-sample
    mean loss and the per-batch mean Jaccard and Dice scores.

    Relies on the module-level ``use_cuda`` flag and the ``jaccard`` / ``dice``
    helpers being in scope.

    Args:
        cust_model: model to train; its parameters are updated in place.
        dataloaders: mapping with "train" and "valid" keys whose values yield
            ``(input_img, labels)`` batches.
        criterion: loss callable invoked as ``criterion(preds, labels)``.
        optimizer: optimizer over ``cust_model``'s parameters.
        num_epochs: number of epochs to run.
        scheduler: optional LR scheduler; when given, ``scheduler.step()`` is
            called once at the end of every epoch.

    Returns:
        ``(cust_model, val_acc_history)``. ``cust_model`` holds the weights from
        the epoch with the highest validation Jaccard; if no epoch improved
        on 0.0, it holds the initial weights. ``val_acc_history`` holds the
        validation Jaccard of every epoch.
    """
    start_time = time.time()
    val_acc_history = []
    best_acc = 0.0
    # state_dict() returns tensors that alias the live parameters, so a deep
    # copy is needed to keep an independent snapshot.
    best_model_wts = copy.deepcopy(cust_model.state_dict())

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("_"*15)
        for phase in ["train", "valid"]:
            loader = dataloaders[phase]
            if phase == "train":
                cust_model.train()
            else:
                cust_model.eval()
            running_loss = 0.0
            n_samples = 0
            jaccard_acc = 0.0
            dice_acc = 0.0

            for input_img, labels in loader:
                input_img = input_img.cuda() if use_cuda else input_img
                labels = labels.cuda() if use_cuda else labels

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == "train"):
                    preds = cust_model(input_img)
                    loss = criterion(preds, labels)

                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                # Weight each batch loss by its size; the sum is divided by
                # the number of samples below.
                running_loss += loss.item() * input_img.size(0)
                n_samples += input_img.size(0)
                # Detach so the metric accumulators do not keep every training
                # batch's autograd graph alive for the rest of the epoch.
                detached_preds = preds.detach()
                jaccard_acc += jaccard(labels, detached_preds)
                dice_acc += dice(labels, detached_preds)

            # Guards avoid ZeroDivisionError on an empty loader.
            n_batches = max(len(loader), 1)
            epoch_loss = running_loss / max(n_samples, 1)
            aver_jaccard = jaccard_acc / n_batches
            aver_dice = dice_acc / n_batches

            print("| {} Loss: {:.4f} | Jaccard Average Acc: {:.4f} | Dice Average Acc: {:.4f} |".format(phase, epoch_loss, aver_jaccard, aver_dice))
            if phase == "valid":
                if aver_jaccard > best_acc:
                    best_acc = aver_jaccard
                    best_model_wts = copy.deepcopy(cust_model.state_dict())
                val_acc_history.append(aver_jaccard)
        print("="*15)
        print(" ")
        if scheduler is not None:
            scheduler.step()
    time_elapsed = time.time() - start_time
    print("Training complete in {:.0f}m {:.0f}s".format(time_elapsed//60, time_elapsed % 60))
    print("Best validation Accuracy: {:.4f}".format(best_acc))
    # Restore the best validation snapshot instead of keeping the last epoch's weights.
    cust_model.load_state_dict(best_model_wts)
    return cust_model, val_acc_history
def train_model(cust_model,
                dataloaders,
                criterion,
                optimizer,
                num_epochs,
                scheduler=None):
    """Train a multi-target segmentation model, checkpointing the best epoch.

    The target tensor is built by concatenating ``labels``, ``inter`` and
    ``contours`` along dim 1 and is compared with the raw model output by
    ``criterion``. Jaccard scores against each of the three targets are
    computed on the sigmoid of the output.

    Training can resume from ``./checkpoint_epoch<N>.pth`` with
    ``N = args["lastepoch"]`` (when ``N >= 1``). After every epoch the best
    model/optimizer snapshot seen so far is written with ``save_checkpoint``.

    Relies on the module-level names ``args``, ``load_checkpoint``,
    ``save_checkpoint``, ``jaccard``, ``tqdm`` and ``optim``.

    Args:
        cust_model: model to train (possibly replaced by the resumed one).
        dataloaders: mapping with "train" and "valid" keys whose values yield
            ``(input_img, labels, inter, contours)`` batches.
        criterion: loss callable ``criterion(preds, label_true)``; its result
            is reduced with ``.mean()`` before ``backward()``.
        optimizer: optimizer over ``cust_model``'s parameters (possibly
            replaced by the resumed one).
        num_epochs: index (1-based) of the last epoch to run.
        scheduler: optional LR scheduler, stepped once per epoch when given.

    Returns:
        ``(cust_model, val_acc_history)``. ``cust_model`` has the weights of
        the best validation epoch loaded. ``val_acc_history`` lists the
        validation Jaccard of each epoch run.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    start_time = time.time()
    val_acc_history = []

    start_epoch = args["lastepoch"] + 1
    if start_epoch > 1:
        filepath = "./checkpoint_epoch" + str(args["lastepoch"]) + ".pth"
        cust_model, optimizer = load_checkpoint(cust_model, filepath)

    # Take the initial "best" snapshot after a possible resume, so the fallback
    # checkpoint reflects the model actually being trained. All best_* values
    # are bound up front because save_checkpoint uses them every epoch, even
    # if validation has not improved yet.
    best_acc = 0.0
    best_acc_inter = 0.0
    best_epoch_loss = float("inf")
    best_model_wts = copy.deepcopy(cust_model)
    best_optimizer_wts = optim.Adam(best_model_wts.parameters(), lr=0.0001)
    best_optimizer_wts.load_state_dict(optimizer.state_dict())

    for epoch in range(start_epoch - 1, num_epochs):
        print("Epoch {}/{}".format(epoch + 1, num_epochs))
        print("_" * 15)
        for phase in ["train", "valid"]:
            loader = dataloaders[phase]
            if phase == "train":
                cust_model.train()
            else:
                cust_model.eval()
            running_loss = 0.0
            n_samples = 0
            jaccard_acc = 0.0
            jaccard_acc_inter = 0.0
            jaccard_acc_contour = 0.0

            for input_img, labels, inter, contours in tqdm(loader, total=len(loader)):
                input_img = input_img.to(device)
                labels = labels.to(device)
                inter = inter.to(device)
                contours = contours.to(device)
                label_true = torch.cat([labels, inter, contours], 1)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == "train"):
                    # The criterion receives the raw output; sigmoid is applied
                    # only for the Jaccard metrics below.
                    preds = cust_model(input_img)
                    loss = criterion(preds, label_true).mean()

                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                # Weight each batch loss by its size; the sum is divided by
                # the number of samples below.
                running_loss += loss.item() * input_img.size(0)
                n_samples += input_img.size(0)

                # A multi-GPU wrapper may return a list/tuple of per-device
                # outputs; merge those. A plain tensor is used as-is
                # (torch.cat only accepts a sequence of tensors).
                if isinstance(preds, (list, tuple)):
                    preds = torch.cat(preds)

                # Detach so the accumulators don't retain training graphs, and
                # compute the probabilities once for all three metrics.
                # NOTE(review): the original author noted that metric
                # accumulation still happens on only one GPU; confirm.
                probs = torch.sigmoid(preds.detach().to('cpu'))
                jaccard_acc += jaccard(labels.to('cpu'), probs)
                jaccard_acc_inter += jaccard(inter.to('cpu'), probs)
                jaccard_acc_contour += jaccard(contours.to('cpu'), probs)

            # Guards avoid ZeroDivisionError on an empty loader.
            n_batches = max(len(loader), 1)
            epoch_loss = running_loss / max(n_samples, 1)
            print("| {} Loss: {:.4f} |".format(phase, epoch_loss))
            aver_jaccard = jaccard_acc / n_batches
            aver_jaccard_inter = jaccard_acc_inter / n_batches
            aver_jaccard_contour = jaccard_acc_contour / n_batches
            print(
                "| {} Loss: {:.4f} | Jaccard Average Acc: {:.4f} | Jaccard Average Acc inter: {:.4f}  | Jaccard Average Acc contour: {:.4f}| "
                .format(phase, epoch_loss, aver_jaccard, aver_jaccard_inter,
                        aver_jaccard_contour))
            print("_" * 15)
            if phase == "valid":
                if aver_jaccard > best_acc:
                    best_acc = aver_jaccard
                    best_acc_inter = aver_jaccard_inter
                    best_epoch_loss = epoch_loss
                    best_model_wts = copy.deepcopy(cust_model)
                    best_optimizer_wts = optim.Adam(best_model_wts.parameters(),
                                                    lr=0.0001)
                    best_optimizer_wts.load_state_dict(optimizer.state_dict())
                val_acc_history.append(aver_jaccard)
        print("^" * 15)
        save_checkpoint(best_model_wts, best_optimizer_wts, epoch + 1,
                        best_epoch_loss, best_acc, best_acc_inter)
        print(" ")
        if scheduler is not None:
            scheduler.step()
    time_elapsed = time.time() - start_time
    print("Training Complete in {:.0f}m {:.0f}s".format(
        time_elapsed // 60, time_elapsed % 60))
    cust_model.load_state_dict(best_model_wts.state_dict())
    return cust_model, val_acc_history
# Example #3
# 0
import torch

from data_loader import Melanoma_Train_Validation_DataLoader
from torchvision import transforms
from fcn_naive_model import fcn_model

from helper import jaccard, dice

# Sanity check: run the FCN model over the validation split and print the
# Jaccard and Dice score of every batch.
use_cuda = torch.cuda.is_available()
segm_model = fcn_model()
train_loader, validation_loader = Melanoma_Train_Validation_DataLoader(
    batch_size=4,
    data_transforms=transforms.Compose(
        [transforms.Resize([512, 512]),
         transforms.ToTensor()]),
    num_workers=2)
if use_cuda:
    segm_model.cuda()
# This loop only evaluates. eval() stops batch-norm statistics from being
# updated with validation data and disables dropout; no_grad() skips building
# autograd graphs.
segm_model.eval()

with torch.no_grad():
    for img, label_img in validation_loader:
        img = img.cuda() if use_cuda else img
        label_img = label_img.cuda() if use_cuda else label_img
        output = segm_model(img)
        out = torch.sigmoid(output)
        print("The Jaccard accuracy is: {:.4f}".format(jaccard(label_img, out)))
        print("The Dice accuracy is: {:.4f}".format(dice(label_img, out)))
# Threshold search: score the validation Jaccard for a sweep of binarisation
# thresholds, save the table to CSV and report the best threshold.
import numpy as np
import pandas as pd

use_cuda = torch.cuda.is_available()

# Hyperparameters
thrs_list = np.linspace(0.1, 0.9, 400)
batch_size = 10
num_workers = 10

_, validation_loader = Melanoma_Train_Validation_DataLoader(batch_size = batch_size,  data_transforms=transforms.Compose([transforms.Resize([512, 512]), transforms.ToTensor()]), num_workers=num_workers)

model = fcn_model().cuda() if use_cuda else fcn_model()
model = load_model(model, model_dir="fcn_15epch_interpol.pt", map_location_device="gpu") if use_cuda else load_model(model, model_dir="fcn_15epch_interpol.pt")
# Inference only: fixed batch-norm/dropout behaviour and no autograd graphs.
model.eval()

# The sigmoid outputs don't depend on the threshold. So run the model once
# per batch and score every threshold against that output, rather than
# re-running inference over the whole loader for each threshold.
jaccard_sums = [0.0] * len(thrs_list)
with torch.no_grad():
    for input_img, label_img in validation_loader:
        input_img = input_img.cuda() if use_cuda else input_img
        label_img = label_img.cuda() if use_cuda else label_img
        preds = torch.sigmoid(model(input_img))
        for k, thrs in enumerate(thrs_list):
            jaccard_sums[k] += jaccard(label_img, (preds > thrs).float())

n_batches = len(validation_loader)
accuracies = [float(total) / n_batches for total in jaccard_sums]
for thrs, acc in zip(thrs_list, accuracies):
    print("Threshold {:.8f} | Jaccard Accuracy: {:.8f}".format(thrs, acc))

# One accuracy per threshold row. Assigning a single scalar to the column
# would overwrite every row with the same value.
thrs_df = pd.DataFrame({"Threshold": thrs_list, "Accuracy": accuracies})

best_row = thrs_df["Accuracy"].idxmax()
optimal_thrs = thrs_df.loc[best_row, "Threshold"]
thrs_df.to_csv("accuracyVSthreshold.csv")
print("Optimal Threshold is {:.8f} found with Accuracy of {:.4f}".format(optimal_thrs, thrs_df.loc[best_row, "Accuracy"]))
# Example #5
# 0
def train_model(cust_model, dataloaders, criterion, optimizer, num_epochs, scheduler=None):
    """Train a two-target segmentation model and restore the best validation weights.

    The target tensor is built by concatenating ``labels`` and ``inter`` along
    dim 1 and is compared with the raw model output by ``criterion``. Jaccard
    scores against ``labels`` and ``inter`` are computed on the sigmoid of
    the output.

    Relies on the module-level ``use_cuda`` flag and the ``jaccard`` and
    ``tqdm`` names being in scope.

    Args:
        cust_model: model to train; its parameters are updated in place.
        dataloaders: mapping with "train" and "valid" keys whose values yield
            ``(input_img, labels, inter)`` batches.
        criterion: loss callable invoked as ``criterion(preds, label_true)``.
        optimizer: optimizer over ``cust_model``'s parameters.
        num_epochs: number of epochs to run.
        scheduler: optional LR scheduler, stepped once per epoch when given.

    Returns:
        ``(cust_model, val_acc_history)``. ``cust_model`` holds the weights
        from the epoch with the highest validation Jaccard (or the initial
        weights if none beat 0.0). ``val_acc_history`` lists each epoch's
        validation Jaccard.
    """
    start_time = time.time()
    val_acc_history = []
    best_acc = 0.0
    # state_dict() returns tensors that alias the live parameters, so a deep
    # copy is needed to keep an independent snapshot.
    best_model_wts = copy.deepcopy(cust_model.state_dict())

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch+1, num_epochs))
        print("_"*15)
        for phase in ["train", "valid"]:
            loader = dataloaders[phase]
            if phase == "train":
                cust_model.train()
            else:
                cust_model.eval()
            running_loss = 0.0
            n_samples = 0
            jaccard_acc = 0.0
            jaccard_acc_inter = 0.0

            for input_img, labels, inter in tqdm(loader, total=len(loader)):
                input_img = input_img.cuda() if use_cuda else input_img
                labels = labels.cuda() if use_cuda else labels
                inter = inter.cuda() if use_cuda else inter
                label_true = torch.cat([labels, inter], 1)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == "train"):
                    # The criterion receives the raw output; sigmoid is applied
                    # only for the Jaccard metrics below.
                    preds = cust_model(input_img)
                    loss = criterion(preds, label_true)

                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                # Weight each batch loss by its size; the sum is divided by
                # the number of samples below.
                running_loss += loss.item() * input_img.size(0)
                n_samples += input_img.size(0)
                # Detach so the accumulators don't retain training graphs, and
                # compute the probabilities once for both metrics.
                probs = torch.sigmoid(preds.detach())
                jaccard_acc += jaccard(labels, probs)
                jaccard_acc_inter += jaccard(inter, probs)

            # Guards avoid ZeroDivisionError on an empty loader.
            n_batches = max(len(loader), 1)
            epoch_loss = running_loss / max(n_samples, 1)
            aver_jaccard = jaccard_acc / n_batches
            aver_jaccard_inter = jaccard_acc_inter / n_batches

            print("| {} Loss: {:.4f} | Jaccard Average Acc: {:.4f} | Jaccard Average Acc inter: {:.4f} |".format(phase, epoch_loss, aver_jaccard,aver_jaccard_inter))
            print("_"*15)
            if phase == "valid":
                if aver_jaccard > best_acc:
                    best_acc = aver_jaccard
                    best_model_wts = copy.deepcopy(cust_model.state_dict())
                val_acc_history.append(aver_jaccard)
        print("^"*15)
        print(" ")
        if scheduler is not None:
            scheduler.step()
    time_elapsed = time.time() - start_time
    print("Training Complete in {:.0f}m {:.0f}s".format(time_elapsed//60, time_elapsed % 60))
    print("Best Validation Accuracy: {:.4f}".format(best_acc))
    # Restore the best validation snapshot instead of keeping the last epoch's weights.
    cust_model.load_state_dict(best_model_wts)
    return cust_model, val_acc_history