class Solver(object):
    """Training / validation / test driver for a binary-segmentation U-Net.

    Wraps model construction, a training loop with linear LR decay and
    best-checkpoint saving (scored by Jaccard + Dice on the validation set),
    and test-time evaluation with CSV reporting.

    Relies on module-level helpers not shown here: ``U_Net``, ``R2U_Net``,
    metric functions (``get_accuracy`` etc.), ``reconstruct_image``, and —
    apparently — module-level lists ``train_accuracy`` / ``validation_accuracy``.
    """

    def __init__(self, config, train_loader, valid_loader, test_loader):
        """Store loaders and config values, then build the network.

        Args:
            config: dict-like of hyper-parameters and paths (keys used:
                'img_ch', 'output_ch', 'lr', 'beta1', 'beta2', 'num_epochs',
                'num_epoches_decay', 'batch_size', 'model_path',
                'result_path', 'model_type', 't').
            train_loader / valid_loader / test_loader: iterables yielding
                (images, GT) batch pairs — presumably torch DataLoaders.
        """
        # data loader
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

        # Models
        self.unet = None
        self.optimizer = None
        self.img_ch = config['img_ch']
        self.output_ch = config['output_ch']
        # BCELoss expects probabilities in [0, 1]; accordingly the raw network
        # output is passed through sigmoid before the loss in train().
        self.criterion = torch.nn.BCELoss()  # binary cross entropy loss

        # Hyper-parameters
        self.lr = config['lr']
        self.beta1 = config['beta1']  # momentum1 in Adam
        self.beta2 = config['beta2']  # momentum2 in Adam

        # Training settings
        self.num_epochs = config['num_epochs']
        # NOTE: the config key is spelled 'num_epoches_decay' (sic) while the
        # attribute is 'num_epochs_decay' — configs must use the key spelling.
        self.num_epochs_decay = config['num_epoches_decay']
        self.batch_size = config['batch_size']

        # Path
        self.model_path = config['model_path']
        self.result_path = config['result_path']

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.model_type = config['model_type']
        self.t = config['t']  # forwarded to R2U_Net as its 't' argument
        # Checkpoint filename encodes model type, epoch count, lr, and the
        # number of decay epochs, so different runs get distinct files.
        self.unet_path = os.path.join(
            self.model_path, '%s-%d-%.4f-%d.pkl' %
            (self.model_type, self.num_epochs, self.lr, self.num_epochs_decay))
        self.best_epoch = 0
        self.build_model()

    def build_model(self):
        """Instantiate the network selected by model_type plus its Adam optimizer.

        Only 'U_Net' and 'R2U_Net' are handled; any other model_type leaves
        self.unet as None, making the optimizer construction below raise.
        """
        if self.model_type == 'U_Net':
            # NOTE(review): channel counts are hard-coded to 1 here, ignoring
            # self.img_ch / self.output_ch from the config — confirm intended.
            self.unet = U_Net(img_ch=1, output_ch=1)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=1, output_ch=1, t=self.t)
            #init_weights(self.unet, 'normal')

        self.optimizer = optim.Adam(list(self.unet.parameters()), self.lr,
                                    (self.beta1, self.beta2))
        self.unet.to(self.device)

    def train(self):
        """Train the network, validating each epoch and saving the best model.

        If a checkpoint already exists at self.unet_path, training is skipped
        entirely and the saved weights are loaded instead. Otherwise, runs
        num_epochs epochs; the checkpoint with the highest validation
        (JS + DC) score is written to self.unet_path.
        """
        # Print out the network information (total trainable parameter count).
        num_params = 0
        for p in self.unet.parameters():
            num_params += p.numel(
            )  # accumulate the number of model parameters
        print("The number of parameters: {}".format(num_params))

        # ====================================== Training ===========================================#

        # network train
        if os.path.isfile(self.unet_path):
            # Load the pretrained Encoder
            self.unet.load_state_dict(torch.load(self.unet_path))
            print('%s is Successfully Loaded from %s' %
                  (self.model_type, self.unet_path))

        else:
            lr = self.lr  # working copy; decayed in-place near the end of training
            best_unet_score = 0.0
            best_epoch = 0

            for epoch in range(self.num_epochs):
                self.unet.train(True)
                epoch_loss = 0

                # Per-epoch running sums of batch metrics; averaged by batch
                # count ('length') after the loop.
                acc = 0.  # Accuracy
                SE = 0.  # Sensitivity (Recall)
                SP = 0.  # Specificity
                PC = 0.  # Precision
                F1 = 0.  # F1 Score
                JS = 0.  # Jaccard Similarity
                DC = 0.  # Dice Coefficient
                length = 0

                for i, (images, GT) in enumerate(self.train_loader):
                    images, GT = images.to(self.device), GT.to(self.device)

                    # forward result
                    SR = self.unet(images)
                    SR_probs = torch.sigmoid(SR)
                    SR_flat = SR_probs.view(SR_probs.size(0),
                                            -1)  # size(0) is batch_size
                    GT_flat = GT.view(GT.size(0), -1)

                    # BCE over flattened per-pixel probabilities vs. ground truth.
                    loss = self.criterion(SR_flat, GT_flat)
                    epoch_loss += loss.item()

                    # Backprop + optimize
                    # (zero_grad on the module works because the optimizer holds
                    # exactly the module's parameters.)
                    self.unet.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    # NOTE(review): metrics here receive the RAW output SR
                    # (pre-sigmoid), while validation/test below pass
                    # sigmoid(SR) — confirm the get_* helpers threshold both
                    # consistently.
                    acc += get_accuracy(SR, GT)
                    SE += get_sensitivity(SR, GT)
                    SP += get_specificity(SR, GT)
                    PC += get_precision(SR, GT)
                    F1 += get_F1(SR, GT)
                    JS += get_JS(SR, GT)
                    DC += get_DC(SR, GT)
                    length = length + 1

                # Average the accumulated batch metrics over the epoch.
                acc = acc / length
                SE = SE / length
                SP = SP / length
                PC = PC / length
                F1 = F1 / length
                JS = JS / length
                DC = DC / length

                # Print the log info
                print(
                    'Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f,'
                    ' F1: %.4f, JS: %.4f, DC: %.4f' %
                    (epoch + 1, self.num_epochs, epoch_loss, acc, SE, SP, PC,
                     F1, JS, DC))
                # NOTE(review): train_accuracy is not defined in this class —
                # presumably a module-level list; NameError if absent.
                train_accuracy.append(acc)

                # Decay learning rate linearly over the final
                # num_epochs_decay epochs of training.
                if (epoch + 1) > (self.num_epochs - self.num_epochs_decay):
                    lr -= (self.lr / float(self.num_epochs_decay))
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Decay learning rate to lr: {}.'.format(lr))

                # ===================================== Validation ====================================#
                self.unet.train(False)
                self.unet.eval()

                acc = 0.  # Accuracy
                SE = 0.  # Sensitivity (Recall)
                SP = 0.  # Specificity
                PC = 0.  # Precision
                F1 = 0.  # F1 Score
                JS = 0.  # Jaccard Similarity
                DC = 0.  # Dice Coefficient
                length = 0
                # NOTE(review): no torch.no_grad() here — validation builds
                # autograd graphs needlessly; confirm memory is acceptable.
                for i, (images, GT) in enumerate(self.valid_loader):
                    images, GT = images.to(self.device), GT.to(self.device)
                    SR = torch.sigmoid(self.unet(images))
                    acc += get_accuracy(SR, GT)
                    SE += get_sensitivity(SR, GT)
                    SP += get_specificity(SR, GT)
                    PC += get_precision(SR, GT)
                    F1 += get_F1(SR, GT)
                    JS += get_JS(SR, GT)
                    DC += get_DC(SR, GT)

                    length = length + 1

                acc = acc / length
                SE = SE / length
                SP = SP / length
                PC = PC / length
                F1 = F1 / length
                JS = JS / length
                DC = DC / length
                # Model-selection score: Jaccard + Dice.
                unet_score = JS + DC

                print('[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, '
                      'F1: %.4f, JS: %.4f, DC: %.4f' %
                      (acc, SE, SP, PC, F1, JS, DC))
                # NOTE(review): validation_accuracy is also not defined in this
                # class — presumably module-level; verify it exists.
                validation_accuracy.append(acc)

                # Keep the checkpoint with the best validation score so far.
                if unet_score > best_unet_score:
                    best_unet_score = unet_score
                    self.best_epoch = epoch
                    best_unet = self.unet.state_dict(
                    )  # contains the best parameters for each layer
                    print('Best %s model score : %.4f' %
                          (self.model_type, best_unet_score))
                    torch.save(best_unet, self.unet_path)

    def test(self):
        """Evaluate the best checkpoint on the test set and append a CSV row.

        Loads weights from self.unet_path, averages metrics over test batches,
        reconstructs the predicted images, and appends one result row to
        '<result_path>/result.csv'.
        """
        self.unet.load_state_dict(torch.load(self.unet_path))
        self.unet.eval()

        acc = 0.  # Accuracy
        SE = 0.  # Sensitivity (Recall)
        SP = 0.  # Specificity
        PC = 0.  # Precision
        F1 = 0.  # F1 Score
        JS = 0.  # Jaccard Similarity
        DC = 0.  # Dice Coefficient
        length = 0
        result = []  # collected per-batch sigmoid outputs (as numpy arrays)
        for i, (images, GT) in enumerate(self.test_loader):
            images = images.to(self.device)
            GT = GT.to(self.device)
            SR = torch.sigmoid(self.unet(images))

            acc += get_accuracy(SR, GT)
            SE += get_sensitivity(SR, GT)
            SP += get_specificity(SR, GT)
            PC += get_precision(SR, GT)
            F1 += get_F1(SR, GT)
            JS += get_JS(SR, GT)
            DC += get_DC(SR, GT)

            length = length + 1

            # Move predictions to CPU numpy for later reconstruction.
            SR = SR.to('cpu')
            SR = SR.detach().numpy()
            result.extend(SR)

        acc = acc / length
        SE = SE / length
        SP = SP / length
        PC = PC / length
        F1 = F1 / length
        JS = JS / length
        DC = DC / length
        # NOTE(review): unet_score is computed but never used in test().
        unet_score = JS + DC

        # NOTE(review): unusual call shape — passes this Solver instance as the
        # first argument to a module-level reconstruct_image helper; confirm.
        reconstruct_image(self, np.array(result))

        # Append one metrics row; file is opened in append mode with no header.
        f = open(os.path.join(self.result_path, 'result.csv'),
                 'a',
                 encoding='utf-8',
                 newline='')
        wr = csv.writer(f)
        wr.writerow([
            self.model_type,
            acc,
            SE,
            SP,
            PC,
            F1,
            JS,
            DC,
            self.lr,
            self.best_epoch,
            self.num_epochs,
            self.num_epochs_decay,
        ])
        f.close()
# ----- Example #2 (scrape separator; original text: "Пример #2", score 0) -----
from utils import dense_crf
from utils import intersectionAndUnion
from PIL import Image
# Restrict CUDA to GPU index 3 before any device query.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# Whether to use CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 2    # binary segmentation (membrane / background)
num_channels = 3   # RGB input
batch_size = 4
size = (256, 256)  # (width, height) target for cv2.resize below
root = "data/membrane/test"
# Collect all .png test images under root (project-local helper).
img_file = search_file(root, [".png"])
# print(img_file)
if __name__ == "__main__":
    model = U_Net(num_channels, num_classes).to(device)
    model.load_state_dict(torch.load('UNet_weights_bilinear_weight.pth'))
    model.eval()
    with torch.no_grad():
        for i in range(1):
            print(img_file[i])
            input = cv2.imread(img_file[i], cv2.IMREAD_COLOR)
            input = cv2.resize(input, size)
            original_img = input
            print(
                os.path.join(
                    "data/membrane/result1",
                    os.path.splitext(os.path.basename(img_file[i]))[0] +
                    "_predict.png"), )
            label = cv2.imread(
                os.path.join(
                    "data/membrane/result1",
# ----- Example #3 (scrape separator; original text: "Пример #3", score 0) -----
# Build loaders; val/test use progressively smaller batches.
# NOTE(review): shuffle=True on val/test loaders is unusual — confirm intended.
train_loader = DataLoader(dataset=train, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(dataset=val, batch_size=batch_size // 2, shuffle=True, num_workers=4)
test_loader = DataLoader(dataset=test, batch_size=batch_size // 4, shuffle=True, num_workers=4)


from nissl_dataset import Nissl_mask_dataset
from network import U_Net
from network import ResAttU_Net


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Two 5-level networks over 3-channel input, 4 output classes.
modelunet = U_Net(UnetLayer=5, img_ch=3, output_ch=4).to(device)
modelresunet = ResAttU_Net(UnetLayer=5,img_ch=3,output_ch=4).to(device)
# strict=False tolerates missing/unexpected keys in the saved state dict.
modelunet.load_state_dict(torch.load('/gdrive/MyDrive/models/unet'), strict=False)

# output = modelunet(image.to(device))

modelresunet.load_state_dict(torch.load('/gdrive/MyDrive/models/resunet'), strict=False)

def gt_to_colorimg(masks):
    
    
    #colors = np.asarray([(201, 58, 64), (242, 207, 1), (0, 152, 75), (101, 172, 228)])#,(56, 34, 132), (160, 194, 56)])

    colors = np.asarray([(0,0,0), (255,0,0), (0,255,0), (0,0,255)])
    colorimg = np.ones((masks.shape[1], masks.shape[2], 3), dtype=np.float32) * 255
    channels, height, width = masks.shape

    for y in range(height):