def test(epoch):
    """Evaluate ``net`` on ``testloader`` with repeated stochastic forward passes.

    Every test batch is tiled ``cf.num_samples`` times along the batch
    dimension before ``net.probforward`` is called. The variational loss
    ``vi(outputs, y, kl, beta)`` is evaluated with the KL weight ``beta``
    chosen by ``cf.beta_type``. When the resulting accuracy beats the
    module-level ``best_acc``, the model is checkpointed.

    Args:
        epoch: Current epoch number. Used by the "Soenderby" beta schedule,
            in the printed/logged summary, and stored in the checkpoint.

    Raises:
        RuntimeError: If ``testloader`` yields no samples.

    Side effects:
        Prints and logs (via ``utils.writeLogs``) a validation summary whose
        loss is the mean over batches. Appends a diagnostics dict to
        ``logfile``. On improvement, saves ``{'net', 'acc', 'epoch'}`` under
        ``./checkpoint/<args.dataset>/`` and updates ``best_acc``.
    """
    global best_acc
    net.eval()
    test_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    m = math.ceil(len(testset) / cf.batch_size)
    # Evaluation only: the forward pass and the loss must run without building
    # an autograd graph (wrapping just Variable() in no_grad had no effect).
    with torch.no_grad():
        for batch_idx, (inputs_value, targets) in enumerate(testloader):
            x = inputs_value.view(-1, inputs, resize, resize).repeat(cf.num_samples, 1, 1, 1)
            y = targets.repeat(cf.num_samples)
            if use_cuda:
                x, y = x.cuda(), y.cuda()
            outputs, kl = net.probforward(x)

            # Compare strings with ==; `is` only works by interning accident.
            if cf.beta_type == "Blundell":
                beta = 2 ** (m - (batch_idx + 1)) / (2 ** m - 1)
            elif cf.beta_type == "Soenderby":
                # max(..., 1) avoids ZeroDivisionError when cf.num_epochs < 4.
                beta = min(epoch / max(cf.num_epochs // 4, 1), 1)
            elif cf.beta_type == "Standard":
                beta = 1 / m
            else:
                beta = 0

            loss = vi(outputs, y, kl, beta)

            # .item() replaces loss.data[0], which raises on 0-dim tensors.
            test_loss += loss.item()
            num_batches += 1
            _, predicted = torch.max(outputs, 1)
            # ``total`` counts un-repeated samples while ``correct`` counts all
            # cf.num_samples copies; ``acc`` below divides that factor out.
            total += targets.size(0)
            correct += predicted.eq(y).sum().item()

    if total == 0:
        raise RuntimeError("test(): testloader yielded no samples")

    mean_loss = test_loss / num_batches
    acc = (100 * correct / total) / cf.num_samples
    summary = "\n| Validation Epoch #%d\t\t\tLoss: %.4f Acc@1: %.2f%%" % (epoch, mean_loss, acc)
    print(summary)
    utils.writeLogs(summary)
    test_diagnostics_to_write = {'Validation Epoch': epoch, 'Loss': mean_loss, 'Accuracy': acc}
    with open(logfile, 'a') as lf:
        # One dict per line so successive entries stay separable.
        lf.write(str(test_diagnostics_to_write) + '\n')

    # Save checkpoint when best model
    if acc > best_acc:
        saving = '| Saving Best model...\t\t\tTop1 = %.2f%%' % (acc)
        print(saving)
        utils.writeLogs(saving)
        state = {
            'net': net,
            'acc': acc,
            'epoch': epoch,
        }
        save_point = './checkpoint/' + args.dataset + os.sep
        # Creates ./checkpoint and the dataset sub-directory in one call.
        os.makedirs(save_point, exist_ok=True)
        torch.save(state, save_point + file_name + str(cf.num_samples) + '.t7')
        best_acc = acc
Beispiel #2
0
def train(epoch):
    """Train ``net`` for one epoch over ``trainloader``.

    A new Adam optimizer is created with learning rate
    ``cf.dynamic_lr(cf.lr, epoch)`` and ``cf.weight_decay``. Each batch is
    reshaped to ``(-1, inputs, resize_origa, resize_origa)`` and passed
    through ``net.probforward``. The network is then optimized on
    ``vi(outputs, y, kl, beta)``, where the KL weight ``beta`` follows
    ``cf.beta_type``.

    Args:
        epoch: Current epoch number. Drives the learning-rate schedule and the
            "Soenderby" beta schedule, and is used in logging.

    Raises:
        RuntimeError: If ``trainloader`` yields no batches.

    Side effects:
        Writes per-iteration progress to stdout and ``utils.writeLogs``.
        Appends the epoch's mean batch loss to the module-level ``trainLoss``
        list. Appends a diagnostics dict to ``logfile``.
    """
    net.train()
    lr = cf.dynamic_lr(cf.lr, epoch)
    # NOTE(review): the optimizer is rebuilt every epoch, which resets Adam's
    # moment estimates each epoch -- confirm this is intended.
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=cf.weight_decay)

    header = '\n=> Training Epoch #%d, LR=%.4f' % (epoch, lr)
    print(header)
    utils.writeLogs(header)

    # Number of *training* minibatches per epoch (assumes the loader keeps the
    # last partial batch). Deriving it from ``testset`` mis-scaled the
    # "Blundell"/"Standard" KL weights and the progress counter.
    m = math.ceil(len(trainset) / cf.batch_size)
    train_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    for batch_idx, (inputs_value, targets) in enumerate(trainloader):
        # as_tensor accepts label lists and tensors alike without the extra
        # copy (and warning) that torch.tensor(tensor) produces.
        targets = torch.as_tensor(targets)
        x = inputs_value.view(-1, inputs, resize_origa, resize_origa)
        y = targets
        if use_cuda:
            x, y = x.cuda(), y.cuda()  # GPU settings

        # Compare strings with ==; `is` only works by interning accident.
        if cf.beta_type == "Blundell":
            beta = 2 ** (m - (batch_idx + 1)) / (2 ** m - 1)
        elif cf.beta_type == "Soenderby":
            # max(..., 1) avoids ZeroDivisionError when cf.num_epochs < 4.
            beta = min(epoch / max(cf.num_epochs // 4, 1), 1)
        elif cf.beta_type == "Standard":
            beta = 1 / m
        else:
            beta = 0

        # Forward propagation, loss, backward propagation, optimizer update.
        outputs, kl = net.probforward(x)
        loss = vi(outputs, y, kl, beta)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss
        num_batches += 1
        _, predicted = torch.max(outputs.detach(), 1)
        total += targets.size(0)
        correct += predicted.eq(y).sum().item()

        progress = ('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f Acc@1: %.3f%%'
                    % (epoch, cf.num_epochs, batch_idx + 1, m,
                       batch_loss, 100 * (correct / total)))
        sys.stdout.write('\r')
        sys.stdout.write(progress)
        utils.writeLogs(progress)
        sys.stdout.flush()

    if num_batches == 0:
        raise RuntimeError("train(): trainloader yielded no batches")

    epoch_loss = train_loss / num_batches
    accuracy = 100 * (correct / total)
    # Append to the module-level history (no local rebinding) so that
    # printMetrics() can plot the training-loss curve.
    trainLoss.append(epoch_loss)
    diagnostics_to_write = {'Epoch': epoch, 'Loss': epoch_loss, 'Accuracy': accuracy}
    utils.writeLogs(str(diagnostics_to_write))
    with open(logfile, 'a') as lf:
        # One dict per line so successive entries stay separable.
        lf.write(str(diagnostics_to_write) + '\n')
def printMetrics(testTargets, testPredicts, epoch):
    """Print and log binary (Healthy vs. Glaucoma) classification metrics.

    Args:
        testTargets: Ground-truth labels accepted by the scikit-learn metrics.
        testPredicts: Predicted labels aligned with ``testTargets``.
        epoch: String tag prepended to every message. It is joined with
            ``+``, so it must be a ``str``. When it equals ``"Final"``, the
            ROC curve, the confusion-matrix heatmap and the training-loss
            curve (from the module-level ``trainLoss``) are saved under
            ``./plots/`` and shown.
    """
    def _emit(text):
        # Send the same text to stdout and to the project log.
        print(text)
        utils.writeLogs(text)

    _emit(epoch + " f1_score:" + str(f1_score(testTargets, testPredicts, average="macro")))
    _emit(epoch + " overall precision:" + str(precision_score(testTargets, testPredicts, average="macro")))
    _emit(epoch + " overall recall : " + str(recall_score(testTargets, testPredicts, average="macro")))

    fpr, tpr, thresholds = roc_curve(testTargets, testPredicts)
    try:
        roc_auc = roc_auc_score(testTargets, testPredicts)
    except ValueError:
        # roc_auc_score raises ValueError when only one class is present.
        roc_auc = None

    _emit(epoch + " False positive : " + str(fpr))
    _emit(epoch + " True positive : " + str(tpr))
    _emit(epoch + " Thresholds : " + str(thresholds))
    if roc_auc is None:
        _emit(epoch + " ROC_AUC : undefined (only one class in targets)")
    else:
        _emit(epoch + " ROC_AUC : " + str(roc_auc))

    confuse = confusion_matrix(testTargets, testPredicts)
    _emit(epoch + " Confusion Matrix : " + str(confuse))
    target_names = ['Healthy', 'Glaucoma']
    _emit(epoch + str(classification_report(testTargets, testPredicts, target_names=target_names, digits=4)))

    if epoch != "Final":
        return

    # savefig does not create missing directories.
    os.makedirs('./plots', exist_ok=True)

    # ROC curve (only meaningful when ROC-AUC is defined).
    if roc_auc is not None:
        plt.figure()
        lw = 2
        plt.plot(fpr, tpr, color='darkorange', lw=lw,
                 label='ROC curve (area = %0.4f)' % roc_auc)
        plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver operating characteristic')
        plt.legend(loc="lower right")
        plt.savefig('./plots/' + 'roc.png', format='png', dpi=300)
        plt.show()
        plt.close()

    # Confusion matrix heatmap; the 2x2 labels below require a 2x2 matrix.
    if confuse.shape == (2, 2):
        heatmapData = pd.DataFrame(confuse,
                                   index=['TrueHealthy', 'TrueGlaucoma'],
                                   columns=['PredictedHealthy', 'PredictedGlaucoma'])
        plt.figure(figsize=(4, 4))
        sns.heatmap(heatmapData, annot=True, fmt='d')
        plt.savefig('./plots/' + 'heatmap.png', format='png', dpi=300)
        plt.show()
        plt.close()

    # Training loss curve: one point per recorded epoch, so the x-axis always
    # matches len(trainLoss) instead of assuming cf.num_epochs entries.
    plt.figure()
    plt.plot(list(range(len(trainLoss))), np.array(trainLoss))
    plt.savefig('./plots/' + 'trainLoss.png', format='png', dpi=300)
    plt.show()
    plt.close()
# Hyper Parameter settings
use_cuda = torch.cuda.is_available()
# torch.cuda.set_device(1)  # uncomment to pin a specific GPU
best_acc = 0  # best validation accuracy so far; updated by test()
resize = 32  # side length used by the image transforms below
resize_origa = 227  # side length train() reshapes its input batches to
# NOTE(review): transform_train produces resize x resize images, but train()
# views batches as resize_origa x resize_origa; presumably the loader actually
# used there applies a different transform -- confirm.

# Module-level histories; printMetrics() plots ``trainLoss``.
trainLoss = []
testLoss = []
testPredicts = []
testTargets = []


# Data Upload
print('\n[Phase 1] : Data Preparation')
utils.writeLogs('\n[Phase 1] : Data Preparation')

transform_train = transforms.Compose([
    transforms.Resize((resize, resize)),
    # Crop size follows ``resize`` (was a hard-coded 32 that would silently
    # diverge if ``resize`` were changed).
    transforms.RandomCrop(resize, padding=4),
    #transforms.RandomHorizontalFlip(),
    #CIFAR10Policy(),
    transforms.ToTensor(),
    transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
])  # meanstd transformation

transform_test = transforms.Compose([
    transforms.Resize((resize, resize)),
    transforms.RandomCrop(32, padding=4),
    #transforms.RandomHorizontalFlip(),
    #CIFAR10Policy(),
def test(epoch):
    """Evaluate ``net`` on ``testloader``, report binary metrics and checkpoint.

    Each batch is reshaped to ``(-1, inputs, resize_origa, resize_origa)``
    and passed through ``net.probforward`` once; inputs are not repeated
    ``cf.num_samples`` times. Predictions and labels from every batch are
    collected on the CPU, so the "overall" scikit-learn metrics describe the
    whole test set rather than only its last batch.

    Args:
        epoch: Current epoch number. Used by the "Soenderby" beta schedule,
            in logs and the checkpoint. It is also compared with
            ``cf.num_epochs`` to decide whether to draw the final plots.

    Raises:
        RuntimeError: If ``testloader`` yields no samples.

    Side effects:
        Prints and logs the metrics and appends a diagnostics dict to
        ``logfile``. On improvement, saves a checkpoint and updates the
        global ``best_acc``. On the last epoch, saves and shows ROC and
        confusion-matrix plots under ``./plots/`` when ROC-AUC is defined.
    """
    global best_acc
    net.eval()

    def _report(text):
        # Send the same text to stdout and to the project log.
        print(text)
        utils.writeLogs(text)

    test_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    predicted_batches = []
    target_batches = []
    m = math.ceil(len(testset) / cf.batch_size)
    # Evaluation only: the forward pass and the loss must run without building
    # an autograd graph (wrapping just Variable() in no_grad had no effect).
    with torch.no_grad():
        for batch_idx, (inputs_value, targets) in enumerate(testloader):
            x = inputs_value.view(-1, inputs, resize_origa, resize_origa)
            y = targets
            if use_cuda:
                x, y = x.cuda(), y.cuda()
            outputs, kl = net.probforward(x)

            # Compare strings with ==; `is` only works by interning accident.
            if cf.beta_type == "Blundell":
                beta = 2 ** (m - (batch_idx + 1)) / (2 ** m - 1)
            elif cf.beta_type == "Soenderby":
                # max(..., 1) avoids ZeroDivisionError when cf.num_epochs < 4.
                beta = min(epoch / max(cf.num_epochs // 4, 1), 1)
            elif cf.beta_type == "Standard":
                beta = 1 / m
            else:
                beta = 0

            loss = vi(outputs, y, kl, beta)

            test_loss += loss.item()
            num_batches += 1
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += predicted.eq(y).sum().item()
            # scikit-learn cannot consume CUDA tensors, so gather on the CPU.
            predicted_batches.append(predicted.cpu())
            target_batches.append(y.cpu())

    if total == 0:
        raise RuntimeError("test(): testloader yielded no samples")

    y_true = torch.cat(target_batches).numpy()
    y_pred = torch.cat(predicted_batches).numpy()

    _report("overall f1_score:" + str(f1_score(y_true, y_pred, average="macro")))
    _report("overall precision:" + str(precision_score(y_true, y_pred, average="macro")))
    _report("overall recall : " + str(recall_score(y_true, y_pred, average="macro")))
    fpr, tpr, thresholds = roc_curve(y_true, y_pred)
    roc_stat = 0
    roc_auc = None
    try:
        roc_auc = roc_auc_score(y_true, y_pred)
    except ValueError:
        # roc_auc_score raises ValueError when only one class is present.
        _report("Same class in targets. Error in computing ROC_AUC")
    else:
        _report("ROC_AUC" + str(roc_auc))
        roc_stat = 1

    _report("False positive" + str(fpr))
    _report("True positive" + str(tpr))
    _report("Thresholds" + str(thresholds))

    confuse = confusion_matrix(y_true, y_pred)
    _report("Confusion Matrix" + str(confuse))
    target_names = ['Healthy', 'Glaucoma']
    _report(str(classification_report(y_true, y_pred, target_names=target_names, digits=4)))

    # Each input is predicted exactly once here, so ``total`` already counts
    # every prediction; dividing by cf.num_samples would under-report accuracy.
    mean_loss = test_loss / num_batches
    acc = 100 * correct / total
    summary = "\n| Validation Epoch #%d\t\t\tLoss: %.4f Acc@1: %.2f%%" % (epoch, mean_loss, acc)
    print(summary)
    utils.writeLogs(summary)
    test_diagnostics_to_write = {
        'Validation Epoch': epoch,
        'Loss': mean_loss,
        'Accuracy': acc
    }
    with open(logfile, 'a') as lf:
        # One dict per line so successive entries stay separable.
        lf.write(str(test_diagnostics_to_write) + '\n')

    # Save checkpoint when best model
    if acc > best_acc:
        saving = '| Saving Best model...\t\t\tTop1 = %.2f%%' % (acc)
        print(saving)
        utils.writeLogs(saving)
        state = {
            'net': net,
            'acc': acc,
            'epoch': epoch,
        }
        save_point = './checkpoint/' + args.dataset + os.sep
        # Creates ./checkpoint and the dataset sub-directory in one call.
        os.makedirs(save_point, exist_ok=True)
        torch.save(state, save_point + file_name + str(cf.num_samples) + '.t7')
        best_acc = acc

    if epoch == cf.num_epochs:
        if roc_stat == 1:
            # savefig does not create missing directories.
            os.makedirs('./plots', exist_ok=True)
            plt.figure()
            lw = 2
            plt.plot(fpr,
                     tpr,
                     color='darkorange',
                     lw=lw,
                     label='ROC curve (area = %0.4f)' % roc_auc)
            plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('Receiver operating characteristic')
            plt.legend(loc="lower right")
            plt.savefig('./plots/' + 'roc.png', format='png', dpi=300)
            plt.show()
            plt.close()

            heatmapData = pd.DataFrame(
                confuse,
                index=['TrueHealthy', 'TrueGlaucoma'],
                columns=['PredictedHealthy', 'PredictedGlaucoma'])
            plt.figure(figsize=(4, 4))
            sns.heatmap(heatmapData, annot=True, fmt='d')
            plt.savefig('./plots/' + 'heatmap.png', format='png', dpi=300)
            plt.show()
            plt.close()