def validate(self, val_loader, epoch_index):
        """Score the checkpoint written for ``epoch_index`` on ``val_loader``.

        A fresh ``screening.screening()`` network is created on the GPU and
        filled with the ``state_dict`` stored in that epoch's checkpoint file.
        The pass is evaluation only: the network is in ``eval()`` mode, runs
        under ``torch.no_grad()`` and is never updated, so the reported loss
        is not influenced by the validation data itself.

        Args:
            val_loader: Iterable of ``(images, labels)`` batches; must
                support ``len()``.
            epoch_index: Epoch number used to build the checkpoint file name.

        Returns:
            The running average (``losses.avg``) of the per-batch
            cross-entropy loss kept by ``AverageMeter``.
        """
        # NOTE(review): the path is hard-coded with Windows separators and is
        # kept byte-for-byte so it keeps matching whatever save_checkpoint
        # writes -- confirm before making it portable.
        path = 'checkpoints\\ss1_1\\checkpoint_' + str(
            epoch_index) + '.pth.tar'
        device = torch.device('cuda')

        model = screening.screening()
        model.to(device)
        # map_location keeps loading independent of the device the checkpoint
        # was saved from.
        state = torch.load(path, map_location=device)
        model.load_state_dict(state['state_dict'])
        model.eval()

        losses = AverageMeter()

        # set a progress bar
        pbar = tqdm(enumerate(val_loader), total=len(val_loader))

        with torch.no_grad():
            for _, (images, labels) in pbar:
                # Add the channel axis expected by the network.
                # NOTE(review): train() feeds its batches without this
                # unsqueeze -- confirm the loaders really differ in layout.
                images = images.float().to(device).unsqueeze(dim=1)
                batch_size = images.size(0)

                # Collapse singleton trailing axes to (batch, classes) while
                # keeping the batch axis, so a final batch holding a single
                # sample still has the layout cross_entropy expects.
                logits = model(images).reshape(batch_size, -1)
                targets = labels.to(device).reshape(-1).long()

                # cross_entropy applies log-softmax internally, so it must be
                # given the raw scores rather than softmax probabilities.
                loss = torch.nn.functional.cross_entropy(logits, targets)
                losses.update(loss.detach(), batch_size)

                pbar.set_description(
                    '[validate] - BATCH LOSS: %.4f/ %.4f(avg) ' %
                    (losses.val, losses.avg))

        return losses.avg
    def train(self, train_loader, model, epoch, num_epochs,
              checkpoint_ss1=None):
        """Train ``model`` in place for one epoch and return the mean loss.

        Args:
            train_loader: Iterable of ``(images, labels)`` batches; must
                support ``len()``.
            model: Network to optimise. It is moved to the GPU and updated in
                place, so the caller's reference always holds the trained
                weights.
            epoch: Zero-based index of the current epoch.
            num_epochs: Total number of epochs (progress display only).
            checkpoint_ss1: Optional checkpoint path. When given, its
                ``state_dict`` is loaded into ``model`` before the first
                epoch (``epoch == 0``); ignored for later epochs.

        Returns:
            The running average (``losses.avg``) of the per-batch
            cross-entropy loss kept by ``AverageMeter``.
        """
        device = torch.device('cuda')

        if epoch == 0 and checkpoint_ss1 is not None:
            # Load into the caller's model rather than a new local network;
            # otherwise the warm-started weights (and the whole first epoch
            # of training) would be thrown away when this method returns.
            checkpoint = torch.load(checkpoint_ss1, map_location=device)
            model.load_state_dict(checkpoint['state_dict'])

        model.to(device)
        model.train()
        losses = AverageMeter()

        # A new optimizer is built on every call, so SGD momentum buffers do
        # not carry over from one epoch to the next.
        optimizer = optim.SGD(model.parameters(), lr=0.03, momentum=0.9)

        pbar = tqdm(enumerate(train_loader), total=len(train_loader))

        for _, (images, labels) in pbar:
            # The batch is used exactly as delivered by the loader.
            # nn.init.normal_ is a weight initialiser: applying it to the
            # batch would overwrite the data in place with noise.
            images = images.float().to(device)
            batch_size = images.size(0)

            optimizer.zero_grad()

            # Collapse singleton trailing axes to (batch, classes) while
            # keeping the batch axis, so a batch holding a single sample
            # still has the layout cross_entropy expects.
            logits = model(images).reshape(batch_size, -1)
            targets = labels.to(device).reshape(-1).long()

            # cross_entropy applies log-softmax internally, so it must be
            # given the raw scores rather than softmax probabilities.
            loss = torch.nn.functional.cross_entropy(logits, targets)
            losses.update(loss.detach(), batch_size)

            loss.backward()
            optimizer.step()

            pbar.set_description(
                '[TRAIN] - EPOCH %d/ %d - BATCH LOSS: %.4f/ %.4f(avg) ' %
                (epoch + 1, num_epochs, losses.val, losses.avg))
        return losses.avg
    def train_ss1(self, train_balanced, valid_balanced):
        """Run the 200-epoch stage-1 screening training loop.

        After every epoch a checkpoint is saved through
        ``self.save_checkpoint`` and scored with ``self.validate``; the
        per-epoch train/validation losses are printed and finally written,
        one row per line, to the file ``screening_stage1_balanced``.

        Args:
            train_balanced: Data handed to ``create_dset_screening_stage1``
                to build the training loader.
            valid_balanced: Data handed to ``create_dset_screening_stage1``
                to build the validation loader.
        """
        num_epochs = 200
        best_loss = 0

        make_converter = create_dataloader.convert2dataloader
        train_loader = make_converter().create_dset_screening_stage1(
            train_balanced)
        val_loader = make_converter().create_dset_screening_stage1(
            valid_balanced)

        model = screening.screening().cuda()

        # This optimizer is never stepped in this method; only its
        # state_dict ends up in the saved checkpoints.
        optimizer = optim.SGD(model.parameters(), lr=0.03, momentum=0.9)

        history = [['Current epoch', 'Train loss', 'Validation loss']]

        for epoch_number in range(1, num_epochs + 1):
            # train for one epoch (train() takes the zero-based index)
            train_loss = self.train(train_loader, model, epoch_number - 1,
                                    num_epochs)
            train_loss = train_loss.item()

            snapshot = {
                'epoch': epoch_number,
                'state_dict': model.state_dict(),
                'best_prec1': best_loss,
                'optimizer': optimizer.state_dict(),
            }
            # The checkpoint must exist before validate() runs, because
            # validate() is given only the epoch number.
            self.save_checkpoint(snapshot, epoch_number)

            val_loss = self.validate(val_loader, epoch_number).item()

            print(epoch_number, train_loss, val_loss)
            history.append([epoch_number, train_loss, val_loss])

        with open('screening_stage1_balanced', 'w') as f:
            f.writelines("%s\n" % row for row in history)
# Example #4
# 0
    def call_screening_model(self, test_dataloader, checkpoint_ss2):
        """Run the screening network over ``test_dataloader``.

        Args:
            test_dataloader: Iterable of input batches (tensors); must
                support ``len()``.
            checkpoint_ss2: Path of the checkpoint whose ``state_dict`` is
                loaded into a fresh ``screening.screening()`` network.

        Returns:
            The softmax scores of the LAST batch produced by the loader
            (earlier batches are evaluated but their scores are not kept),
            or ``None`` when the loader yields no batch at all.
        """
        device = torch.device('cuda')
        model_fcn = screening.screening().to(device)

        # map_location keeps loading independent of the device the checkpoint
        # was saved from.
        state = torch.load(checkpoint_ss2, map_location=device)
        # load params; strict=False tolerates keys that do not match the
        # network exactly.
        model_fcn.load_state_dict(state['state_dict'], strict=False)
        model_fcn.eval()

        pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))

        # Defined up front so an empty loader returns None instead of
        # raising UnboundLocalError.
        pred_candidate_score = None

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            for _, patch in pbar:
                patch = patch.float().to(device)

                # dim=1 is the class axis; it is also what the deprecated
                # implicit-dim softmax picked for 2-D, 4-D and 5-D outputs.
                pred_candidate_score = torch.nn.functional.softmax(
                    model_fcn(patch), dim=1)

        return pred_candidate_score
    def call_train_ss2(self, complete_data, checkpoint_ss1):
        """Run the stage-2 screening training loop for ``self.num_epochs``.

        After every epoch a checkpoint is saved through
        ``self.save_checkpoint`` and scored with ``self.validate``; the
        per-epoch train/validation losses are printed and finally written,
        one row per line, to ``screening_stage2.txt``.

        Args:
            complete_data: Data handed to
                ``cmb_dataloader().create_train_val_dset`` to obtain the
                training and validation loaders.
            checkpoint_ss1: Checkpoint path forwarded to ``self.train`` on
                every epoch.
        """
        best_loss = 0

        loaders = cmb_dataloader().create_train_val_dset(complete_data)
        train_loader, val_loader = loaders

        model = screening.screening().cuda()

        # This optimizer is never stepped in this method; only its
        # state_dict ends up in the saved checkpoints.
        optimizer = optim.SGD(model.parameters(), lr=0.03, momentum=0.9)

        history = [['Current epoch', 'Train loss', 'Validation loss']]

        for epoch_number in range(1, self.num_epochs + 1):
            # train for one epoch (train() takes the zero-based index)
            train_loss = self.train(train_loader, model, epoch_number - 1,
                                    self.num_epochs, checkpoint_ss1)
            train_loss = train_loss.item()

            snapshot = {
                'epoch': epoch_number,
                'state_dict': model.state_dict(),
                'best_prec1': best_loss,
                'optimizer': optimizer.state_dict(),
            }
            # The checkpoint must exist before validate() runs, because
            # validate() is given only the epoch number.
            self.save_checkpoint(snapshot, epoch_number)

            val_loss = self.validate(val_loader, epoch_number).item()

            print(epoch_number, train_loss, val_loss)
            history.append([epoch_number, train_loss, val_loss])

        with open('screening_stage2.txt', 'w') as f:
            f.writelines("%s\n" % row for row in history)
# Example #6
# 0
def prepare_datset_with_mimics(subjects, checkpoint):
    """Collect the samples the screening network flags as positive.

    Every ``(image, mask)`` pair in ``subjects`` is converted to arrays and
    labelled ``1.0`` when the mask's maximum equals 1.0, else ``0.0``. The
    network restored from ``checkpoint`` then scores all samples, and those
    predicted positive are kept: true positives with label ``1.0`` and false
    positives (the "mimics") with label ``0.0``.

    Args:
        subjects: Iterable of ``(image, mask)`` pairs accepted by
            ``itk.GetArrayFromImage``.
        checkpoint: Path of a checkpoint holding a ``'state_dict'`` entry.

    Returns:
        A shuffled list of ``[image_tensor, label]`` pairs. The tensors are
        the per-sample slices of the GPU batches fed to the network.
    """
    data = []
    for x, y in subjects:
        x = np.array(itk.GetArrayFromImage(x))
        y = np.array(itk.GetArrayFromImage(y))
        data.append([x, 1.0 if y.max() == 1.0 else 0.0])

    dataloader = create_dataloader.convert2dataloader(
    ).create_dset_screening_stage1(data)

    device = torch.device('cuda')
    model = screening.screening()
    model.to(device)
    state = torch.load(checkpoint, map_location=device)

    # load params
    model.load_state_dict(state['state_dict'])
    model.eval()

    false_positive = []
    positive = []

    # set a progress bar
    pbar = tqdm(enumerate(dataloader), total=len(dataloader))

    # Inference only: without no_grad every stored sample would keep the
    # autograd graph of its batch alive.
    with torch.no_grad():
        for _, (images, labels) in pbar:
            images = images.float().to(device).unsqueeze(dim=1)
            batch_size = images.size(0)

            # Collapse singleton trailing axes to (batch, classes) while
            # keeping the batch axis, so a batch with one sample still works.
            scores = torch.nn.functional.softmax(
                model(images).reshape(batch_size, -1), dim=1)

            # A sample is predicted positive unless class 0 holds the
            # (possibly tied) maximum score.
            predicted_positive = scores[:, 0] < scores.max(dim=1)[0]

            # One host transfer per batch instead of one per element.
            batch_labels = labels.reshape(-1).long().tolist()
            batch_flags = predicted_positive.tolist()

            for sample, label, flagged in zip(images, batch_labels,
                                              batch_flags):
                if not flagged:
                    continue
                if label == 0:
                    false_positive.append([sample, 0.0])
                elif label == 1:
                    positive.append([sample, 1.0])

    random.shuffle(positive)
    random.shuffle(false_positive)

    complete_dataset_discrimination = positive + false_positive

    random.shuffle(complete_dataset_discrimination)

    return complete_dataset_discrimination
def prepare_datset_with_mimics(train, checkpoint):
    """Build the stage-2 data set from the screening network's predictions.

    The network restored from ``checkpoint`` scores every sample in
    ``train``. Samples are sorted into false positives, true negatives and
    true positives, each group is passed through ``remove_percentage``, and
    the results are merged and shuffled. False negatives are not part of the
    returned data.

    Args:
        train: Data handed to ``create_dset_screening_stage1`` to build the
            loader of ``(images, labels)`` batches.
        checkpoint: Path of a checkpoint holding a ``'state_dict'`` entry.

    Returns:
        A shuffled list of ``[image_tensor, label]`` pairs. The tensors are
        the per-sample slices of the GPU batches fed to the network.
    """
    dataloader = create_dataloader.convert2dataloader(
    ).create_dset_screening_stage1(train)

    device = torch.device('cuda')
    model = screening.screening()
    model.to(device)
    state = torch.load(checkpoint, map_location=device)

    # load params
    model.load_state_dict(state['state_dict'])
    model.eval()

    false_positive = []  # 28.85%
    positive = []  # 23.63 %
    negative = []  # 47.52 %

    # set a progress bar
    pbar = tqdm(enumerate(dataloader), total=len(dataloader))

    # Inference only: without no_grad every stored sample would keep the
    # autograd graph of its batch alive.
    with torch.no_grad():
        for _, (images, labels) in pbar:
            images = images.float().to(device).unsqueeze(dim=1)
            batch_size = images.size(0)

            # Collapse singleton trailing axes to (batch, classes) while
            # keeping the batch axis, so a batch with one sample still works.
            scores = torch.nn.functional.softmax(
                model(images).reshape(batch_size, -1), dim=1)

            # A sample is predicted positive unless class 0 holds the
            # (possibly tied) maximum score.
            predicted_positive = scores[:, 0] < scores.max(dim=1)[0]

            # One host transfer per batch instead of one per element.
            batch_labels = labels.reshape(-1).long().tolist()
            batch_flags = predicted_positive.tolist()

            for sample, label, flagged in zip(images, batch_labels,
                                              batch_flags):
                if label == 0:
                    bucket = false_positive if flagged else negative
                    bucket.append([sample, 0.0])
                elif label == 1 and flagged:
                    positive.append([sample, 1.0])

    random.shuffle(negative)
    random.shuffle(positive)
    random.shuffle(false_positive)

    # NOTE(review): 29 is passed here while the other two calls pass
    # fractions (0.47, 0.24); the note on false_positive above reads 28.85%,
    # so 0.29 may have been intended. remove_percentage is defined elsewhere
    # -- confirm its contract before changing this value.
    new_false_positive = remove_percentage(false_positive, 29)

    new_negative_list = remove_percentage(negative, 0.47)

    new_positive = remove_percentage(positive, 0.24)

    complete_dataset_stage2 = (new_false_positive + new_negative_list +
                               new_positive)

    random.shuffle(complete_dataset_stage2)

    return complete_dataset_stage2