def main():
    """Train the locator model end-to-end, then save weights and run settings.

    Relies on the module-level ``parser`` and helper functions; all outputs
    are written under ``args.prefix``.
    """
    global args, logger
    args = parser.parse_args()
    set_prefix(args.prefix, __file__)
    model = model_builder()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # let cudnn autotune convolution algorithms for fixed-size inputs
    cudnn.benchmark = True

    train_loader, val_loader = load_dataset()
    # class_names=['LESION', 'NORMAL']
    class_names = train_loader.dataset.class_names
    print(class_names)

    # step-wise learning rate decay
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    since = time.time()
    print('-' * 10)
    for epoch in range(args.epochs):
        # adjust weight once unet can be nearly seen as an identical mapping
        train(train_loader, model, optimizer, epoch)
        # BUGFIX: PyTorch >= 1.1 expects scheduler.step() AFTER the epoch's
        # optimizer updates; stepping first skipped the initial lr value.
        exp_lr_scheduler.step()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    with torch.no_grad():
        validate(model, train_loader, val_loader)
    # save_typical_result(model)
    torch.save(model.state_dict(), add_prefix(args.prefix, 'locator.pkl'))
    write(vars(args), add_prefix(args.prefix, 'paras.txt'))
 def save(self, saved_path, name, inputs):
     """Save the auto-encoder's reconstruction of *inputs* as a single image.

     :param saved_path: base output directory; '_single' is appended
     :param name: file name of the saved image
     :param inputs: tensor fed to ``self.auto_encoder``
     """
     saved_path += '_single'
     if not os.path.exists(saved_path):
         os.makedirs(saved_path)
     output = self.auto_encoder(inputs)
     # undo input normalization so the tensor can be written as a plain image
     output = self.restore(output)
     # NOTE(review): scipy.misc.imsave was deprecated and removed in SciPy 1.2;
     # confirm the pinned SciPy version or migrate to imageio.imwrite.
     scipy.misc.imsave(add_prefix(saved_path, name), output)
     print('file %s is saved to %s successfully.' %(name, add_prefix(saved_path, name)))
Exemplo n.º 3
0
    def save_init_paras(self):
        """Persist the untrained generator and discriminator weights.

        Writes ``init_g_para.pkl`` and ``init_d_para.pkl`` under ``self.prefix``.
        """
        if not os.path.exists(self.prefix):
            os.makedirs(self.prefix)
        for network, filename in ((self.unet, 'init_g_para.pkl'),
                                  (self.d, 'init_d_para.pkl')):
            torch.save(network.state_dict(), add_prefix(self.prefix, filename))
        print('save initial model parameters successfully')
Exemplo n.º 4
0
def main():
    """Train a classifier and checkpoint on the best (lowest) validation loss.

    Updates the module-level ``min_loss``/``best_acc`` trackers and writes
    checkpoints plus the run's parameters under ``args.prefix``.
    """
    global args, min_loss, best_acc
    args = parser.parse_args()
    device_counts = torch.cuda.device_count()
    print('there is %d gpus in usage' % (device_counts))
    # save source script
    set_prefix(args.prefix, __file__)
    model = model_selector(args.model_type)
    print(model)
    if args.cuda:
        model = DataParallel(model).cuda()
    else:
        raise RuntimeError('there is no gpu')

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # let cudnn autotune convolution algorithms
    cudnn.benchmark = True

    train_loader, val_loader = load_dataset()
    # class_names=['LESION', 'NORMAL']
    class_names = train_loader.dataset.class_names
    print(class_names)
    criterion = nn.BCELoss().cuda()

    # learning rate decay per epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=args.step_size,
                                           gamma=args.gamma)
    since = time.time()
    print('-' * 10)
    for epoch in range(args.epochs):
        train(train_loader, model, optimizer, criterion, epoch)
        # scheduler steps after the optimizer updates (PyTorch >= 1.1 order)
        exp_lr_scheduler.step()
        cur_loss, cur_acc = validate(model, val_loader, criterion)
        is_best = cur_loss < min_loss
        # BUGFIX: the running minimum must be written back to the global
        # ``min_loss``; it was previously stored in a local ``best_loss``,
        # so every epoch kept comparing against the initial value.
        min_loss = min(cur_loss, min_loss)
        if is_best:
            best_acc = cur_acc
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model_type,
                'state_dict': model.state_dict(),
                'min_loss': min_loss,
                'acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    check_point = torch.load(add_prefix(args.prefix, args.best_model_path))
    print('min_loss=%.4f, best_acc=%.4f' %
          (check_point['min_loss'], check_point['acc']))
    write(vars(args), add_prefix(args.prefix, 'paras.txt'))
Exemplo n.º 5
0
def main():
    """Train densenet121 (2 classes), checkpointing on best validation accuracy.

    Updates the module-level ``best_acc`` tracker; all artifacts are written
    under ``args.prefix``.
    """
    global args, best_acc
    args = parser.parse_args()
    # save source script
    set_prefix(args.prefix, __file__)
    model = models.densenet121(pretrained=False, num_classes=2)
    if args.cuda:
        model = DataParallel(model).cuda()
    else:
        warnings.warn('there is no gpu')

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # let cudnn autotune convolution algorithms
    cudnn.benchmark = True

    train_loader, val_loader = load_dataset()
    # class_names=['LESION', 'NORMAL']
    class_names = train_loader.dataset.classes
    print(class_names)
    if args.is_focal_loss:
        print('try focal loss!!')
        criterion = FocalLoss().cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    # learning rate decay per epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=args.step_size,
                                           gamma=args.gamma)
    since = time.time()
    print('-' * 10)
    for epoch in range(args.epochs):
        train(train_loader, model, optimizer, criterion, epoch)
        # scheduler steps after the optimizer updates (PyTorch >= 1.1 order)
        exp_lr_scheduler.step()
        cur_accuracy = validate(model, val_loader, criterion)
        is_best = cur_accuracy > best_acc
        best_acc = max(cur_accuracy, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                # BUGFIX: record the architecture actually trained
                # (was hard-coded to 'resnet18')
                'arch': 'densenet121',
                'state_dict': model.state_dict(),
                'best_accuracy': best_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # compute validate meter such as confusion matrix
    compute_validate_meter(model, add_prefix(args.prefix,
                                             args.best_model_path), val_loader)
    # save running parameter setting to json
    write(vars(args), add_prefix(args.prefix, 'paras.txt'))
def validate(model, epoch, data_loader):
    """Dump sample model outputs for a few batches and checkpoint the weights.

    Outside of the final epoch only the first 3 batches are sampled; the
    generator state is saved as ``g.pkl`` under the epoch's prefix directory.
    """
    for batch_idx, (lesion_data, _, lesion_names, _, real_data, _,
                    normal_names, _) in enumerate(data_loader):
        if batch_idx > 2 and epoch != args.epochs:
            break
        if args.cuda:
            lesion_data, real_data = lesion_data.cuda(), real_data.cuda()
        # handle the lesion and normal halves of the batch the same way
        for phase, batch, names in (('lesion', lesion_data, lesion_names),
                                    ('normal', real_data, normal_names)):
            prefix_path = '%s/epoch_%d/%s' % (args.prefix, epoch, phase)
            nums = min(args.batch_size, batch.size(0))
            for idx in range(nums):
                save_single_image(prefix_path, model, names[idx],
                                  batch[idx:(idx + 1), :, :, :])
                if args.debug:
                    break

    prefix_path = '%s/epoch_%d' % (args.prefix, epoch)
    torch.save(model.state_dict(), add_prefix(prefix_path, 'g.pkl'))
    print('save model parameters successfully when epoch=%d' % epoch)
def save_single_image(saved_path, model, name, inputs):
    """Render input, unet output, and two difference heatmaps into one figure.

    The figure is written to ``saved_path/name`` (directory created on demand).
    """
    if not os.path.exists(saved_path):
        os.makedirs(saved_path)
    reconstruction = model(inputs)

    source = restore(inputs)
    result = restore(reconstruction)

    # branch-free |source - result|: plain subtraction could underflow
    # whenever source < result
    gap = np.where(source > result, source - result,
                   result - source).clip(0, 255).astype(np.uint8)

    plt.figure(num='unet result', figsize=(8, 8))

    plt.subplot(2, 2, 1)
    plt.title('source image')
    plt.imshow(source)
    plt.axis('off')

    plt.subplot(2, 2, 2)
    plt.title('unet output')
    plt.imshow(result)
    plt.axis('off')

    plt.subplot(2, 2, 3)
    plt.imshow(rgb2gray(gap), cmap='jet')
    plt.colorbar(orientation='horizontal')
    plt.title('difference in heatmap')
    plt.axis('off')

    plt.subplot(2, 2, 4)
    plt.imshow(rgb2gray(gap.clip(0, 32)), cmap='jet')
    plt.colorbar(orientation='horizontal')
    plt.axis('off')

    plt.tight_layout()
    plt.savefig(add_prefix(saved_path, name))
    plt.close()
Exemplo n.º 8
0
def load_pretrained_model(prefix):
    """Build a resnet18 and load weights from ``model_best.pth.tar`` under *prefix*."""
    model = resnet18(is_ptrtrained=False)
    checkpoint = torch.load(add_prefix(prefix, 'model_best.pth.tar'))
    print('load pretrained resnet18 successfully.')
    # checkpoint keys carry a DataParallel 'module.' prefix that must be stripped
    model.load_state_dict(remove_prefix(checkpoint['state_dict']))
    # print('best acc=%.4f' % checkpoint['best_accuracy'])
    return model
Exemplo n.º 9
0
    def save_image(self, saved_path, name, inputs):
        """Save a 2x2 figure: input, unet output, and two difference heatmaps.

        The figure is written to ``saved_path/name`` (directory created on demand).
        """
        if not os.path.exists(saved_path):
            os.makedirs(saved_path)
        reconstruction = self.unet(inputs)

        source = self.restore(inputs)
        result = self.restore(reconstruction)
        # Take |source - result| branch-free: plain subtraction would
        # underflow whenever source < result (e.g. 217 - 220 wraps to 253).
        gap = np.where(source > result, source - result,
                       result - source).clip(0, 255).astype(np.uint8)

        plt.figure(num='unet result', figsize=(8, 8))

        plt.subplot(2, 2, 1)
        plt.title('source image')
        plt.imshow(source)
        plt.axis('off')

        plt.subplot(2, 2, 2)
        plt.title('unet output')
        plt.imshow(result)
        plt.axis('off')

        plt.subplot(2, 2, 3)
        plt.imshow(rgb2gray(gap), cmap='jet')
        plt.colorbar(orientation='horizontal')
        plt.title('difference in heatmap')
        plt.axis('off')

        plt.subplot(2, 2, 4)
        plt.imshow(rgb2gray(gap.clip(0, 32)), cmap='jet')
        plt.colorbar(orientation='horizontal')
        plt.axis('off')

        plt.tight_layout()
        plt.savefig(add_prefix(saved_path, name))
        plt.close()
def save_single_image(name, inputs, output, label, phase):
    """Save one figure with network input, output, and their difference.

    :param name: output file name
    :param inputs: network input
    :param output: network output
    :param label: image label: 'lesion' or 'normal'
    :param phase: image source: 'training' or 'validate' dataset
    :return: None
    """
    source = restore(inputs)
    result = restore(output)

    plt.figure(num='unet result', figsize=(8, 8))

    plt.subplot(2, 2, 1)
    plt.title('source image')
    plt.imshow(source)
    plt.axis('off')

    plt.subplot(2, 2, 2)
    plt.title('output image')
    plt.imshow(result)
    plt.axis('off')

    # branch-free |source - result| clamped to the valid pixel range
    gap = np.where(source > result, source - result,
                   result - source).clip(0, 255).astype(np.uint8)
    plt.subplot(2, 2, 3)
    plt.imshow(rgb2gray(gap), cmap='jet')
    plt.colorbar()

    plt.tight_layout()
    plt.savefig(add_prefix(args.prefix, '%s/%s/%s' % (phase, label, name)))
    plt.close()
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Print *cm* and render it as an annotated matplotlib image.

    Set ``normalize=True`` to show per-row (true-label) fractions instead of
    raw counts.  Adapted from the scikit-learn example:
    http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    cell_fmt = '.2f' if normalize else 'd'
    # annotate every cell; flip text color to white on dark cells
    cutoff = cm.max() / 2.
    for row in range(cm.shape[0]):
        for col in range(cm.shape[1]):
            plt.text(col, row, format(cm[row, col], cell_fmt),
                     horizontalalignment="center",
                     color="white" if cm[row, col] > cutoff else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig(add_prefix(args.prefix, 'confusion_matrix.png'))
    plt.close()
def save_single_image(label, name, inputs, output, unet, phase):
    """Save input, unet reconstruction with class probabilities, and abs-diff.

    :param label: ground-truth class ('lesion' or 'normal')
    :param name: output file name
    :param inputs: classifier input image tensor
    :param output: classifier logits (softmaxed for the title)
    :param unet: unet reconstruction of the input
    :param phase: dataset split used to build the output directory
    """
    left = restore(inputs)
    right = restore(unet)

    # save network input image
    plt.figure(num='unet result', figsize=(8, 8))
    plt.subplot(2, 2, 1)
    plt.title('source image: %s' % label)
    plt.imshow(left)
    plt.axis('off')
    predict = F.softmax(output, dim=1)
    predict = to_np(predict).flatten()

    # save network output image
    plt.subplot(2, 2, 2)
    plt.title('lesion: %.2f, normal: %.2f' % (predict[0], predict[1]))
    plt.imshow(right)
    plt.axis('off')

    # save difference directly
    # CONSISTENCY FIX: cast to uint8 like the sibling save_single_image
    # variants so the difference image has a well-defined pixel dtype
    diff = np.where(left > right, left - right,
                    right - left).clip(0, 255).astype(np.uint8)
    plt.subplot(2, 2, 3)
    plt.imshow(rgb2gray(diff), cmap='jet')
    plt.colorbar()
    plt.title('difference in abs gray')

    plt.tight_layout()
    plt.savefig(add_prefix(args.prefix, '%s/%s/%s' % (phase, label, name)))
    plt.close()
Exemplo n.º 13
0
    def validate(self, epoch):
        """
        eval mode: dump sample reconstructions for the first few batches,
        plot the discriminator score distributions, then checkpoint both
        the generator and the discriminator.
        """
        real_data_score = []
        fake_data_score = []
        for i, (lesion_data, _, lesion_names, _, real_data, _, normal_names,
                _) in enumerate(self.dataloader):
            # only sample the first 3 batches
            if i > 2:
                break
            if self.use_gpu:
                lesion_data, real_data = lesion_data.cuda(), real_data.cuda()
            phase = 'lesion_data'
            prefix_path = '%s/epoch_%d/%s' % (self.prefix, epoch, phase)
            lesion_output = self.d(self.unet(lesion_data))
            fake_data_score += list(
                lesion_output.squeeze().cpu().data.numpy().flatten())

            # BUGFIX: the final batch can be shorter than self.batch_size;
            # clamp to the actual length to avoid IndexError on the names
            # (mirrors the guard used by the module-level validate()).
            nums = min(self.batch_size, lesion_data.size(0))
            for idx in range(nums):
                single_image = lesion_data[idx:(idx + 1), :, :, :]
                single_name = lesion_names[idx]
                self.save_image(prefix_path, single_name, single_image)
                if self.debug:
                    break

            phase = 'normal_data'
            prefix_path = '%s/epoch_%d/%s' % (self.prefix, epoch, phase)
            normal_output = self.d(real_data)
            real_data_score += list(
                normal_output.squeeze().cpu().data.numpy().flatten())

            nums = min(self.batch_size, real_data.size(0))
            for idx in range(nums):
                single_image = real_data[idx:(idx + 1), :, :, :]
                single_name = normal_names[idx]
                self.save_image(prefix_path, single_name, single_image)
                if self.debug:
                    break

        prefix_path = '%s/epoch_%d' % (self.prefix, epoch)

        self.plot_hist('%s/score_distribution.png' % prefix_path,
                       real_data_score, fake_data_score)
        torch.save(self.unet.state_dict(), add_prefix(prefix_path, 'g.pkl'))
        torch.save(self.d.state_dict(), add_prefix(prefix_path, 'd.pkl'))
        print('save model parameters successfully when epoch=%d' % epoch)
def plt_roc(test_y, probas_y, plot_micro=False, plot_macro=False):
    """Draw the ROC curve for the given labels/probabilities and save it.

    The figure is written to ``roc_auc_curve.png`` under ``args.prefix``.
    """
    assert isinstance(test_y, list) and isinstance(probas_y, list), \
        'the type of input must be list'
    skplt.metrics.plot_roc(test_y, probas_y,
                           plot_micro=plot_micro,
                           plot_macro=plot_macro)
    plt.savefig(add_prefix(args.prefix, 'roc_auc_curve.png'))
    plt.close()
def main(prefix, epoch, data_dir):
    """Sweep every binarization threshold and record the average dice loss."""
    saved_path = '../%s/dice_loss%s/' % (prefix, epoch)
    criterion = DiceLoss(prefix, epoch, data_dir)
    results = {}
    # note the range of threshold: too small a value drives entire images
    # toward 1, so the dice loss looks high but is meaningless
    for thresh in range(1, 256):
        avg_dice_loss = criterion(thresh)
        results[thresh] = avg_dice_loss
        print('avg dice loss=%.4f,thresh=%d' % (avg_dice_loss, thresh))
    write(results, add_prefix(saved_path, 'results.txt'))
Exemplo n.º 16
0
def load_pretrained_model(pretrained_path, model_type):
    """Load a vgg19 or resnet18 classifier checkpoint.

    :param pretrained_path: directory containing ``model_best.pth.tar``
    :param model_type: 'vgg' or 'resnet'
    :raises ValueError: on an unsupported ``model_type``
    """
    checkpoint = torch.load(add_prefix(pretrained_path, 'model_best.pth.tar'))
    if model_type == 'vgg':
        model = vgg19(pretrained=False, num_classes=2)
        print('load vgg successfully.')
    elif model_type == 'resnet':
        model = resnet18(is_ptrtrained=False)
        print('load resnet18 successfully.')
    else:
        # BUGFIX: the empty message made this failure undiagnosable
        raise ValueError('unsupported model_type: %r (expected "vgg" or "resnet")'
                         % model_type)
    model.load_state_dict(remove_prefix(checkpoint['state_dict']))
    return model
def load_pretrained_model(prefix, model_type):
    """Load the best checkpoint for a resnet18 or vgg19 classifier.

    :param prefix: directory containing ``model_best.pth.tar``
    :param model_type: 'resnet' or 'vgg'
    :raises ValueError: on an unsupported ``model_type``
    """
    if model_type == 'resnet':
        model = resnet18(is_ptrtrained=False)
    elif model_type == 'vgg':
        model = vgg19(num_classes=2, pretrained=False)
    else:
        # BUGFIX: the empty message made this failure undiagnosable
        raise ValueError('unsupported model_type: %r (expected "resnet" or "vgg")'
                         % model_type)

    checkpoint = torch.load(add_prefix(prefix, 'model_best.pth.tar'))
    print('load pretrained model successfully.')
    model.load_state_dict(remove_prefix(checkpoint['state_dict']))
    print('best acc=%.4f' % checkpoint['best_accuracy'])
    return model
def main():
    """Train the depth-5 UNet as a near-identity mapping and checkpoint it.

    Runs periodic qualitative validation every 40 epochs and saves the final
    weights as ``identical_mapping.pkl`` under ``args.prefix``.
    """
    global args, logger
    args = parser.parse_args()
    # logger = Logger(add_prefix(args.prefix, 'logs'))
    set_prefix(args.prefix, __file__)
    model = UNet(3, depth=5, in_channels=3)
    print(model)
    print('load unet with depth=5')
    if args.cuda:
        model = DataParallel(model).cuda()
    else:
        raise RuntimeError('there is no gpu')
    # per-element L1 loss; 'reduction' replaces the deprecated/removed
    # reduce=False keyword (removed in modern PyTorch releases)
    criterion = nn.L1Loss(reduction='none').cuda()
    print('use l1_loss')
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # accelerate the speed of training
    cudnn.benchmark = True

    data_loader = get_dataloader()
    # class_names=['LESION', 'NORMAL']
    # class_names = data_loader.dataset.class_names
    # print(class_names)

    since = time.time()
    print('-' * 10)
    for epoch in range(1, args.epochs + 1):
        train(data_loader, model, optimizer, criterion, epoch)
        # periodic qualitative check + checkpoint
        if epoch % 40 == 0:
            validate(model, epoch, data_loader)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    validate(model, args.epochs, data_loader)
    # save model parameter
    torch.save(model.state_dict(),
               add_prefix(args.prefix, 'identical_mapping.pkl'))
    # save running parameter setting to json
    write(vars(args), add_prefix(args.prefix, 'paras.txt'))
    def contrast(self, saved_path, name, inputs):
        """Render the input next to the auto-encoder output plus two
        difference heatmaps, skipping work when the file already exists.
        """
        if not os.path.exists(saved_path):
            os.makedirs(saved_path)
        target = add_prefix(saved_path, name)
        if os.path.exists(target):
            return
        reconstruction = self.auto_encoder(inputs)

        source = self.restore(inputs)
        result = self.restore(reconstruction)

        # branch-free |source - result| clamped to the valid pixel range
        gap = np.where(source > result, source - result,
                       result - source).clip(0, 255).astype(np.uint8)

        plt.figure(num='unet result', figsize=(8, 8))
        plt.subplot(2, 2, 1)
        plt.title('source image')
        plt.imshow(source)
        plt.axis('off')
        plt.subplot(2, 2, 2)
        plt.title('unet output')
        plt.imshow(result)
        plt.axis('off')
        plt.subplot(2, 2, 3)
        plt.imshow(rgb2gray(gap), cmap='jet')
        plt.colorbar(orientation='horizontal')
        plt.title('difference in heatmap')
        plt.axis('off')
        plt.subplot(2, 2, 4)
        plt.imshow(rgb2gray(gap.clip(0, 32)), cmap='jet')
        plt.colorbar(orientation='horizontal')
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(target)
        print('file %s is saved to %s successfully.' %(name, target))
        plt.close()
Exemplo n.º 20
0
    def __init__(self, args):
        """Wire up a GAN trainer from parsed CLI args.

        Copies hyper-parameters, builds the dataloader, discriminator and
        unet generator, moves both networks to GPU, snapshots the initial
        weights and records the run's hyper-parameters.
        """
        # initialize hyper-parameters
        self.data = args.data
        self.gan_type = args.gan_type
        self.d_depth = args.d_depth
        # NOTE(review): 'dowmsampling' spelling mirrors the CLI attribute -- keep in sync
        self.dowmsampling = args.dowmsampling
        self.gpu_counts = args.gpu_counts
        self.power = args.power
        self.batch_size = args.batch_size
        self.use_gpu = torch.cuda.is_available()
        self.u_depth = args.u_depth
        self.is_pretrained_unet = args.is_pretrained_unet
        self.pretrain_unet_path = args.pretrain_unet_path

        self.lr = args.lr
        self.debug = args.debug
        self.prefix = args.prefix
        self.interval = args.interval
        self.n_update_gan = args.n_update_gan
        self.epochs = args.epochs
        self.gamma = args.gamma
        self.beta1 = args.beta1

        self.training_strategies = args.training_strategies
        # checkpoint every epoch while debugging, otherwise every 50 epochs
        self.epoch_interval = 1 if self.debug else 50

        self.logger = Logger(add_prefix(self.prefix, 'tensorboard'))
        # normalize the images between [-1 and 1]
        self.mean = [0.5, 0.5, 0.5]
        self.std = [0.5, 0.5, 0.5]
        self.dataloader = self.get_dataloader()
        self.d = get_discriminator(self.gan_type, self.d_depth,
                                   self.dowmsampling)
        self.unet = self.get_unet()

        self.log_lst = []

        if self.use_gpu:
            self.unet = DataParallel(self.unet).cuda()
            self.d = DataParallel(self.d).cuda()
        else:
            # NOTE(review): raising a Warning class aborts like an error here;
            # confirm whether warnings.warn was intended instead
            raise RuntimeWarning('there is no gpu available.')
        # snapshot untrained weights before any optimization step
        self.save_init_paras()
        self.get_optimizer()
        self.save_hyperparameters(args)
Exemplo n.º 21
0
 def save_hyperparameters(self, args):
     """Dump the parsed CLI arguments to ``para.txt`` under the run prefix."""
     destination = add_prefix(self.prefix, 'para.txt')
     write(vars(args), destination)
     print('save hyperparameters successfully.')
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Persist the training state after each epoch; mirror it to the
    best-model path whenever this epoch is the best so far."""
    checkpoint_path = add_prefix(args.prefix, filename)
    torch.save(state, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path,
                        add_prefix(args.prefix, args.best_model_path))
Exemplo n.º 23
0
 def load_config(self):
     """Read back the hyper-parameters previously written to ``para.txt``."""
     config_path = add_prefix(self.prefix, 'para.txt')
     return read(config_path)
def load_pretrained_model(prefix):
    """Build a 2-class vgg19 and load weights from ``model_best.pth.tar``."""
    model = vgg19(num_classes=2, pretrained=False)
    checkpoint = torch.load(add_prefix(prefix, 'model_best.pth.tar'))
    print('load pretrained vgg19 successfully.')
    # checkpoint keys carry a DataParallel 'module.' prefix that must be stripped
    model.load_state_dict(remove_prefix(checkpoint['state_dict']))
    return model
Exemplo n.º 25
0
 def save_running_script(self, script_path):
     """
     save the main running script to get differences between scripts
     """
     # os.path.basename handles platform path separators correctly,
     # unlike splitting on '/' by hand
     copy(script_path, add_prefix(self.prefix, os.path.basename(script_path)))
Exemplo n.º 26
0
 def save_log(self):
     """Flush the accumulated per-epoch log entries to ``log.txt``."""
     log_path = add_prefix(self.prefix, 'log.txt')
     write_list(self.log_lst, log_path)
     print('save running log successfully')