Example #1

# Imports needed by these examples (standard library and PyTorch); the
# project-local helpers listed below are assumed to be importable from the
# surrounding codebase, whose exact module layout is not shown here.
import csv
import glob
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

# Assumed project-local names: dataset, U_Net, R2U_Net, get_test_indices,
# get_train_cv_indices, get_dataloader, runningScore, averageMeter, save_net,
# save_output, get_accuracy, get_sensitivity, get_specificity, get_precision,
# get_F1, get_JS, get_DC, reconstruct_image.

def train(cycle_num,
          dirs,
          path_to_net,
          plotter,
          batch_size=12,
          test_split=0.3,
          random_state=666,
          epochs=100,
          learning_rate=0.0001,
          momentum=0.9,
          num_folds=5,
          num_slices=155,
          n_classes=4):
    """
    Applies training on the network
        Args: 
            cycle_num (int): number of cycle in n-fold (num_folds) cross validation
            dirs (string): path to dataset subject directories 
            path_to_net (string): path to directory where to save network
            plotter (callable): visdom plotter
            batch_size - default (int): batch size
            test_split - default (float): percentage of test split 
            random_state - default (int): seed for k-fold cross validation
            epochs - default (int): number of epochs
            learning_rate - default (float): learning rate 
            momentum - default (float): momentum
            num_folds - default (int): number of folds in cross validation
            num_slices - default (int): number of slices per volume
            n_classes - default (int): number of classes (regions)
    """
    print('Setting started', flush=True)

    # Create data indices: one index per subject directory in dirs
    indices = np.arange(len(glob.glob(dirs + '*')))
    test_indices, trainset_indices = get_test_indices(indices, test_split)
    # kfold index generator
    for cv_num, (train_indices, val_indices) in enumerate(
            get_train_cv_indices(trainset_indices, num_folds, random_state)):
        # The num_folds-fold CV is split into num_folds jobs; run only this cycle
        if cv_num != int(cycle_num):
            continue

        net = U_Net()
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        num_GPU = torch.cuda.device_count()
        if num_GPU > 1:
            print('Let us use {} GPUs!'.format(num_GPU), flush=True)
            net = nn.DataParallel(net)
        net.to(device)
        criterion = nn.CrossEntropyLoss()
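        # Alternate the optimizer across CV cycles: SGD on even cycle numbers, Adam on odd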
        if cycle_num % 2 == 0:
            optimizer = optim.SGD(net.parameters(),
                                  lr=learning_rate,
                                  momentum=momentum)
        else:
            optimizer = optim.Adam(net.parameters(), lr=learning_rate)

        # Lower the LR as soon as the validation loss stops improving (patience=0);
        # stepped with val_loss_meter.avg at the end of each epoch
        scheduler = ReduceLROnPlateau(optimizer, threshold=1e-6, patience=0)

        print('cv cycle number: ', cycle_num, flush=True)
        start = time.time()
        print('Start Train and Val loading', flush=True)

        MRIDataset_train = dataset.MRIDataset(dirs, train_indices)

        MRIDataset_val = dataset.MRIDataset(dirs, val_indices)

        datalengths = {
            'train': len(MRIDataset_train),
            'val': len(MRIDataset_val)
        }
        dataloaders = {
            'train': get_dataloader(MRIDataset_train, batch_size, num_GPU),
            'val': get_dataloader(MRIDataset_val, batch_size, num_GPU)
        }
        print('Train and Val loading took: ', time.time() - start, flush=True)
        # Keep separate loss and accuracy histories for train and val
        # Set up metrics
        running_metrics_val = runningScore(n_classes)
        running_metrics_train = runningScore(n_classes)
        val_loss_meter = averageMeter()
        train_loss_meter = averageMeter()
        itr = 0
        iou_best = 0.
        for epoch in tqdm(range(epochs), desc='Epochs'):
            print('Epoch: ', epoch + 1, flush=True)
            phase = 'train'
            print('Phase: ', phase, flush=True)
            start = time.time()
            # Set model to training mode
            net.train()
            # Iterate over data.
            for i, data in tqdm(enumerate(dataloaders[phase]),
                                desc='Data Iteration ' + phase):
                if (i + 1) % 100 == 0:
                    print('Number of Iteration [{}/{}]'.format(
                        i + 1, int(datalengths[phase] / batch_size)),
                          flush=True)
                # get the inputs
                inputs = data['mri_data'].to(device)
                GT = data['seg'].to(device)
                subject_slice_path = data['subject_slice_path']
                # Clear all accumulated gradients
                optimizer.zero_grad()
                # Predict classes using inputs from the train set
                SR = net(inputs)
                # Compute the loss based on the predictions and
                # actual segmentation
                loss = criterion(SR, GT)
                # Backpropagate the loss
                loss.backward()
                # Adjust parameters according to the computed
                # gradients
                # -- weight update
                optimizer.step()
                # Track and plot metrics and loss
                predictions = SR.data.max(1)[1].cpu().numpy()
                GT_cpu = GT.data.cpu().numpy()
                running_metrics_train.update(GT_cpu, predictions)
                train_loss_meter.update(loss.item(), n=1)
                if (i + 1) % 100 == 0:
                    itr += 1
                    score, class_iou = running_metrics_train.get_scores()
                    for k, v in score.items():
                        plotter.plot(k, 'itr', phase, k, itr, v)
                    for k, v in class_iou.items():
                        print('Class {} IoU: {}'.format(k, v), flush=True)
                        plotter.plot(
                            str(k) + ' Class IoU', 'itr', phase,
                            str(k) + ' Class IoU', itr, v)
                    print('Loss Train', train_loss_meter.avg, flush=True)
                    plotter.plot('Loss', 'itr', phase, 'Loss Train', itr,
                                 train_loss_meter.avg)
            print('Phase {} took {} s for whole {}set!'.format(
                phase,
                time.time() - start, phase),
                  flush=True)

            # Validation Phase
            phase = 'val'
            print('Phase: ', phase, flush=True)
            start = time.time()
            # Set model to evaluation mode
            net.eval()
            with torch.no_grad():
                # Iterate over data.
                for i, data in tqdm(enumerate(dataloaders[phase]),
                                    desc='Data Iteration ' + phase):
                    if (i + 1) % 100 == 0:
                        print('Number of Iteration [{}/{}]'.format(
                            i + 1, int(datalengths[phase] / batch_size)),
                              flush=True)
                    # get the inputs
                    inputs = data['mri_data'].to(device)
                    GT = data['seg'].to(device)
                    subject_slice_path = data['subject_slice_path']
                    # Predict classes using inputs from the validation set;
                    # no optimizer.zero_grad() is needed under torch.no_grad()
                    SR = net(inputs)
                    # Compute the loss based on the predictions and
                    # actual segmentation
                    loss = criterion(SR, GT)
                    # Track and plot metrics and loss
                    predictions = SR.data.max(1)[1].cpu().numpy()
                    GT_cpu = GT.data.cpu().numpy()
                    running_metrics_val.update(GT_cpu, predictions)
                    val_loss_meter.update(loss.item(), n=1)
                    if (i + 1) % 100 == 0:
                        itr += 1
                        score, class_iou = running_metrics_val.get_scores()
                        for k, v in score.items():
                            plotter.plot(k, 'itr', phase, k, itr, v)
                        for k, v in class_iou.items():
                            print('Class {} IoU: {}'.format(k, v), flush=True)
                            plotter.plot(
                                str(k) + ' Class IoU', 'itr', phase,
                                str(k) + ' Class IoU', itr, v)
                        print('Loss Val', val_loss_meter.avg, flush=True)
                        plotter.plot('Loss ', 'itr', phase, 'Loss Val', itr,
                                     val_loss_meter.avg)
                if (epoch + 1) % 10 == 0:
                    # Recompute the scores here so they are defined even when
                    # this epoch ran fewer than 100 validation iterations
                    score, _ = running_metrics_val.get_scores()
                    if score['Mean IoU'] > iou_best:
                        save_net(path_to_net, batch_size, epoch, cycle_num,
                                 train_indices, val_indices, test_indices, net,
                                 optimizer)
                        iou_best = score['Mean IoU']
                    save_output(epoch, path_to_net, subject_slice_path,
                                SR.data.cpu().numpy(), GT_cpu)
                print('Phase {} took {} s for whole {}set!'.format(
                    phase,
                    time.time() - start, phase),
                      flush=True)
            # Call the learning rate adjustment function after every epoch
            scheduler.step(val_loss_meter.avg)
    # save network after training
    save_net(path_to_net,
             batch_size,
             epochs,
             cycle_num,
             train_indices,
             val_indices,
             test_indices,
             net,
             optimizer,
             iter_num=None)
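
A minimal sketch of how train might be invoked, assuming a hypothetical no-op stand-in for the project's visdom plotter and hypothetical dataset/output paths:

# NoOpPlotter and the paths below are illustrative assumptions, not part of
# the original code.
class NoOpPlotter:
    def plot(self, *args, **kwargs):
        pass  # a real run would supply the project's visdom plotter


if __name__ == '__main__':
    train(cycle_num=0,                   # run fold 0 of the cross-validation
          dirs='/data/brats/subjects/',  # hypothetical dataset root, globbed with '*'
          path_to_net='/data/models/',   # hypothetical checkpoint directory
          plotter=NoOpPlotter())
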
Example #2
class Solver(object):
    def __init__(self, config, train_loader, valid_loader, test_loader):
        # data loader
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

        # Models
        self.unet = None
        self.optimizer = None
        self.img_ch = config['img_ch']
        self.output_ch = config['output_ch']
        self.criterion = torch.nn.BCELoss()  # binary cross entropy loss

        # Hyper-parameters
        self.lr = config['lr']
        self.beta1 = config['beta1']  # momentum1 in Adam
        self.beta2 = config['beta2']  # momentum2 in Adam

        # Training settings
        self.num_epochs = config['num_epochs']
        self.num_epochs_decay = config['num_epoches_decay']
        self.batch_size = config['batch_size']

        # Path
        self.model_path = config['model_path']
        self.result_path = config['result_path']

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.model_type = config['model_type']
        self.t = config['t']
        self.unet_path = os.path.join(
            self.model_path, '%s-%d-%.4f-%d.pkl' %
            (self.model_type, self.num_epochs, self.lr, self.num_epochs_decay))
        self.best_epoch = 0
        self.build_model()

    def build_model(self):
        """Build generator and discriminator."""
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=1, output_ch=1)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=1, output_ch=1, t=self.t)
            #init_weights(self.unet, 'normal')

        self.optimizer = optim.Adam(list(self.unet.parameters()), self.lr,
                                    (self.beta1, self.beta2))
        self.unet.to(self.device)

    def train(self):
        """Train the network, validating after every epoch."""
        # Print out the network information
        num_params = 0
        for p in self.unet.parameters():
            num_params += p.numel()  # accumulate the number of model parameters
        print("The number of parameters: {}".format(num_params))

        # ====================================== Training ===========================================#

        # network train
        if os.path.isfile(self.unet_path):
            # Load the previously saved network weights
            self.unet.load_state_dict(torch.load(self.unet_path))
            print('%s is Successfully Loaded from %s' %
                  (self.model_type, self.unet_path))

        else:
            lr = self.lr
            best_unet_score = 0.0
            best_epoch = 0

            for epoch in range(self.num_epochs):
                self.unet.train(True)
                epoch_loss = 0

                acc = 0.  # Accuracy
                SE = 0.  # Sensitivity (Recall)
                SP = 0.  # Specificity
                PC = 0.  # Precision
                F1 = 0.  # F1 Score
                JS = 0.  # Jaccard Similarity
                DC = 0.  # Dice Coefficient
                length = 0

                for i, (images, GT) in enumerate(self.train_loader):
                    images, GT = images.to(self.device), GT.to(self.device)

                    # Forward pass: SR (segmentation result) vs. GT (ground truth)
                    SR = self.unet(images)
                    SR_probs = torch.sigmoid(SR)
                    SR_flat = SR_probs.view(SR_probs.size(0),
                                            -1)  # size(0) is batch_size
                    GT_flat = GT.view(GT.size(0), -1)

                    loss = self.criterion(SR_flat, GT_flat)
                    epoch_loss += loss.item()

                    # Backprop + optimize
                    self.unet.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    # Compute metrics on probabilities, consistent with the
                    # validation and test loops below
                    acc += get_accuracy(SR_probs, GT)
                    SE += get_sensitivity(SR_probs, GT)
                    SP += get_specificity(SR_probs, GT)
                    PC += get_precision(SR_probs, GT)
                    F1 += get_F1(SR_probs, GT)
                    JS += get_JS(SR_probs, GT)
                    DC += get_DC(SR_probs, GT)
                    length = length + 1

                acc = acc / length
                SE = SE / length
                SP = SP / length
                PC = PC / length
                F1 = F1 / length
                JS = JS / length
                DC = DC / length

                # Print the log info
                print(
                    'Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f,'
                    ' F1: %.4f, JS: %.4f, DC: %.4f' %
                    (epoch + 1, self.num_epochs, epoch_loss, acc, SE, SP, PC,
                     F1, JS, DC))
                train_accuracy.append(acc)  # assumes a module-level train_accuracy list

                # Decay learning rate
                if (epoch + 1) > (self.num_epochs - self.num_epochs_decay):
                    lr -= (self.lr / float(self.num_epochs_decay))
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Decay learning rate to lr: {}.'.format(lr))

                # ===================================== Validation ====================================#
                self.unet.eval()  # switch to evaluation mode (same as train(False))

                acc = 0.  # Accuracy
                SE = 0.  # Sensitivity (Recall)
                SP = 0.  # Specificity
                PC = 0.  # Precision
                F1 = 0.  # F1 Score
                JS = 0.  # Jaccard Similarity
                DC = 0.  # Dice Coefficient
                length = 0
                for i, (images, GT) in enumerate(self.valid_loader):
                    images, GT = images.to(self.device), GT.to(self.device)
                    SR = torch.sigmoid(self.unet(images))
                    acc += get_accuracy(SR, GT)
                    SE += get_sensitivity(SR, GT)
                    SP += get_specificity(SR, GT)
                    PC += get_precision(SR, GT)
                    F1 += get_F1(SR, GT)
                    JS += get_JS(SR, GT)
                    DC += get_DC(SR, GT)

                    length = length + 1

                acc = acc / length
                SE = SE / length
                SP = SP / length
                PC = PC / length
                F1 = F1 / length
                JS = JS / length
                DC = DC / length
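                # Model-selection score: Jaccard similarity plus Dice coefficient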
                unet_score = JS + DC

                print('[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, '
                      'F1: %.4f, JS: %.4f, DC: %.4f' %
                      (acc, SE, SP, PC, F1, JS, DC))
                validation_accuracy.append(acc)  # assumes a module-level validation_accuracy list

                if unet_score > best_unet_score:
                    best_unet_score = unet_score
                    self.best_epoch = epoch
                    # state_dict holds the best parameters for each layer
                    best_unet = self.unet.state_dict()
                    print('Best %s model score : %.4f' %
                          (self.model_type, best_unet_score))
                    torch.save(best_unet, self.unet_path)

    def test(self):
        self.unet.load_state_dict(torch.load(self.unet_path))
        self.unet.eval()

        acc = 0.  # Accuracy
        SE = 0.  # Sensitivity (Recall)
        SP = 0.  # Specificity
        PC = 0.  # Precision
        F1 = 0.  # F1 Score
        JS = 0.  # Jaccard Similarity
        DC = 0.  # Dice Coefficient
        length = 0
        result = []
        for i, (images, GT) in enumerate(self.test_loader):
            images = images.to(self.device)
            GT = GT.to(self.device)
            SR = torch.sigmoid(self.unet(images))

            acc += get_accuracy(SR, GT)
            SE += get_sensitivity(SR, GT)
            SP += get_specificity(SR, GT)
            PC += get_precision(SR, GT)
            F1 += get_F1(SR, GT)
            JS += get_JS(SR, GT)
            DC += get_DC(SR, GT)

            length = length + 1

            SR = SR.to('cpu')
            SR = SR.detach().numpy()
            result.extend(SR)

        acc = acc / length
        SE = SE / length
        SP = SP / length
        PC = PC / length
        F1 = F1 / length
        JS = JS / length
        DC = DC / length
        unet_score = JS + DC

        reconstruct_image(self, np.array(result))

        # Append the test metrics as one row to result.csv
        with open(os.path.join(self.result_path, 'result.csv'),
                  'a',
                  encoding='utf-8',
                  newline='') as f:
            wr = csv.writer(f)
            wr.writerow([
                self.model_type,
                acc,
                SE,
                SP,
                PC,
                F1,
                JS,
                DC,
                self.lr,
                self.best_epoch,
                self.num_epochs,
                self.num_epochs_decay,
            ])
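
A minimal wiring sketch for Solver, with hypothetical paths and assumed pre-built DataLoaders; the config keys mirror exactly those read in __init__ (including the 'num_epoches_decay' spelling it expects):

# The loaders and paths below are illustrative assumptions.
config = {
    'img_ch': 1,
    'output_ch': 1,
    'lr': 0.0002,
    'beta1': 0.5,
    'beta2': 0.999,
    'num_epochs': 100,
    'num_epoches_decay': 70,   # key spelling as read by __init__
    'batch_size': 4,
    'model_path': './models',  # hypothetical checkpoint directory
    'result_path': './results',  # hypothetical results directory
    'model_type': 'U_Net',     # or 'R2U_Net'
    't': 3,                    # recurrence steps, used only by R2U_Net
}
# train_loader, valid_loader, test_loader: assumed torch.utils.data.DataLoader objects
solver = Solver(config, train_loader, valid_loader, test_loader)
solver.train()
solver.test()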