Beispiel #1
0
    def __init__(self, root, split='train', mode=None):
        """Prepare augmentation, the on-disk npy data and the arrays for ``mode``.

        Args:
            root: Directory that contains the ``BASE_DIR`` sub-directory.
            split: Forwarded unchanged to the parent constructor.
            mode: ``'train'``, ``'val'`` or ``'test'`` selects which arrays
                are loaded and sets ``self.size``; for any other value
                (including the default ``None``) no arrays are loaded here
                and ``self.size`` is not assigned by this method.
        """
        super(Promise12, self).__init__(root, split=split, mode=mode)
        self.mode = mode

        # From here on ``root`` refers to the dataset-specific directory.
        root = root + '/' + self.BASE_DIR

        # Joint augmentation pipeline, applied in the order listed.
        augmentations = [
            RandomTranslate(offset=(0.2, 0.1)),
            RandomVerticallyFlip(),
            RandomHorizontallyFlip(),
            RandomElasticTransform(alpha=1.5, sigma=0.07, img_type='F'),
        ]
        self.joint_transform = Compose(augmentations)
        self.img_normalize = None

        # The converted arrays live in ``<root>/npy_image``; they are built
        # only when that directory does not exist yet.
        data_path = os.path.join(root, 'npy_image')
        if os.path.exists(data_path):
            print('read the data from: {}'.format(data_path))
        else:
            create_exp_dir(data_path, 'Create augmentation data at {}')
            data_to_array(root, data_path, self.CROP_SIZE, self.CROP_SIZE)

        self.test_file_list = get_test_list(root)

        # Load only the arrays that belong to the requested mode.
        if mode == 'train':
            images, labels = load_train_data(data_path)
            self.X_train, self.y_train = images, labels
            self.size = images.shape[0]
        elif mode == 'val':
            images, labels = load_val_data(data_path)
            self.X_val, self.y_val = images, labels
            self.size = images.shape[0]
        elif mode == 'test':
            images, slice_array = load_test_data(data_path)
            self.X_test, self.x_slice_array = images, slice_array
            self.size = images.shape[0]
Beispiel #2
0
def main(format='svg'):
    """Plot the ``down`` and ``up`` cells of a genotype from ``geno_types``.

    The genotype name is read from the command line when exactly one
    argument is given; otherwise the usage text is printed and the default
    name is used.  Exits with status 1 when ``geno_types`` has no attribute
    of that name.

    Args:
        format: Output format handed to ``visualize.plot``; it is also used
            as a sub-directory name under ``./cell_visualize``.
    """
    default_name = 'NAS_UNET_NEW_V2'
    if len(sys.argv) == 2:
        genotype_name = sys.argv[1]
    else:
        genotype_name = default_name
        # Report the default that is really used (the old text named a
        # different genotype than the one selected).
        print('usage:\n python {} ARCH_NAME, Default: {}'.format(
            sys.argv[0], default_name))

    store_path = './cell_visualize/{}/{}'.format(format, genotype_name)
    create_exp_dir(store_path)

    if 'Windows' in platform.platform():
        # Make the bundled graphviz binaries discoverable on Windows.
        os.environ['PATH'] += os.pathsep + '../3rd_tools/graphviz-2.38/bin/'

    # NOTE(security): this used to be eval() on a command-line argument,
    # which executes arbitrary expressions.  getattr() performs the same
    # attribute lookup and raises the same AttributeError for unknown names.
    try:
        genotype = getattr(geno_types, genotype_name)
    except AttributeError:
        print('{} is not specified in geno_types.py'.format(genotype_name))
        sys.exit(1)

    visualize.plot(genotype.down, store_path + '/DownC', format=format)
    visualize.plot(genotype.up, store_path + '/UpC', format=format)
Beispiel #3
0
    res = [[s + 1, l + 1] for s, l in zip(list(start), list(length))]
    res = list(chain.from_iterable(res))
    return ' '.join([str(r) for r in res])


if __name__ == '__main__':
    # Encode every ``*_mask.tif`` in ``input_path`` with ``rle_encoding`` and
    # write the results to ``submission/submission.csv``.
    input_path = '../../../predictions/nerve_rst'
    masks = [f for f in os.listdir(input_path) if f.endswith('_mask.tif')]
    # File names begin with an integer followed by '_'; order by that integer
    # rather than lexicographically.
    masks = sorted(masks, key=lambda d: int(d.split('_')[0]))

    encodings = []
    for m in masks:
        # Use the image as a context manager so the underlying file handle
        # is released as soon as the pixel data has been copied out.
        with Image.open(os.path.join(input_path, m)) as img:
            # img.size is (width, height); the array is reshaped to
            # (height, width).
            x = np.array(img.getdata(),
                         dtype=np.uint8).reshape(img.size[::-1])
        x = x // 255  # 255 -> 1, everything below 255 -> 0
        print('----->: processing: {}'.format(m))
        encodings.append(rle_encoding(x))
    print('Encode done, write to submission file')

    first_row = 'img,pixels'
    create_exp_dir('./submission', '=> create submission file')
    file_name = os.path.join('submission', 'submission.csv')
    # The ``with`` block closes the file; no explicit close() is needed.
    with open(file_name, 'w+') as f:
        f.write(first_row + '\n')
        # Rows are numbered 1..N in the sorted mask order.
        for row_id, encoding in enumerate(encodings, start=1):
            f.write(str(row_id) + ',' + encoding + '\n')
    print('write done!')
Beispiel #4
0
    def test(self, img_queue, split='val', desc=''):
        """Run the model over ``img_queue`` and report metrics or save masks.

        For batches whose ``target`` is not a list, the loss, the running
        metric and the dice coefficient are accumulated.  For batches whose
        ``target`` is a list (its entries are used as file paths), every
        predicted mask is saved under ``desc`` as ``<name>_mask.tif``.

        Args:
            img_queue: Iterable of ``(input, target)`` batches that also
                supports ``len()``.
            split: Name used in log messages and in the ``<split>_rst``
                result directory suffix.
            desc: Directory the mask images are written to.  The literal
                value ``'promise12'`` additionally enables the Promise12
                specific collection and evaluation of predictions.

        Returns:
            None.  When ``img_queue`` yields no batch the method logs a
            message and returns early instead of failing on the averages.
        """
        self.model.eval()
        predict_list = []
        accuracy = 0
        tbar = tqdm(img_queue)
        create_exp_dir(desc, desc='=>Save prediction image on')

        # Hold the last processed batch; ``predicts`` stays None when the
        # queue yields nothing.
        images = target = predicts = None
        with torch.no_grad():
            for step, (images, target) in enumerate(tbar):
                images = images.cuda(self.device)
                has_label = not isinstance(target, list)
                if has_label:
                    target = target.cuda(self.device)

                # predicts is a sequence: [pred, aux_pred (optional)]
                predicts = self.model(images)

                if has_label:
                    # for cityscapes, voc, camvid, test have label
                    test_loss = self.criterion(predicts[0], target)
                    self.loss_meter.update(test_loss.item())
                    self.metric.update(target, predicts[0])
                    if step % self.cfg['training']['report_freq'] == 0:
                        pixAcc, mIoU = self.metric.get()
                        self.logger.info(
                            '{} loss: {}, pixAcc: {}, mIoU: {}'.format(
                                split, self.loss_meter.mloss, pixAcc, mIoU))
                        tbar.set_description(
                            'loss: %.6f, pixAcc: %.3f, mIoU: %.6f' %
                            (self.loss_meter.mloss, pixAcc, mIoU))
                    accuracy += dice_coefficient(predicts[0].cpu(),
                                                 target.cpu())
                else:
                    # Compute the class map once per batch instead of once
                    # per sample.
                    labels = torch.argmax(predicts[0].cpu(), 1)
                    for i in range(predicts[0].shape[0]):
                        if self.args.crf:  # refine the mask with dense CRF
                            refined = dense_crf(
                                np.array(images[i].cpu()).astype(np.uint8),
                                labels[i]) > 0.5
                            mask = (refined * 255).astype(np.uint8)
                        else:
                            mask = (labels[i] * 255).numpy().astype(np.uint8)
                        # target[i] is a path; keep the part of its base
                        # name before the first '.'.
                        file_name = os.path.split(target[i])[1]
                        file_name = file_name.split('.')[0] + '_mask.tif'
                        Image.fromarray(mask).save(
                            os.path.join(desc, file_name))

                if desc == 'promise12':  # for promise12, test have not label or have label to calc extra metric
                    predict_list += [
                        torch.argmax(predicts[0], dim=1).cpu().numpy()
                    ]

        if predicts is None:
            # Nothing was processed: the averages and the last-batch logic
            # below would fail on an empty queue.
            self.logger.info(
                '{}: img_queue is empty, nothing to evaluate'.format(split))
            return

        print('==> accuracy: {}'.format(accuracy / len(img_queue)))

        # cause the predicts is a list [pred, aux_pred(may not)]
        if len(predicts[0].shape) == 4:
            pred = predicts[0]
        else:
            pred = predicts

        # save images (uses the last batch of the loop)
        if not isinstance(target, list) and not isinstance(target, str):
            store_images(images, pred, target)
            pixAcc, mIoU = self.metric.get()
            self.logger.info('{}/loss: {}, pixAcc: {}, mIoU: {}'.format(
                split, self.loss_meter.mloss, pixAcc, mIoU))
        elif desc == 'promise12':  # for promise12, test have not label
            predict_test(predict_list, target,
                         self.save_path + '/{}_rst'.format(split))

        # for promise12 metric
        if desc == 'promise12' and split == 'val':
            val_list = [5, 15, 25, 35, 45]
            train_dir = os.path.join(self.trainset.root_dir,
                                     self.trainset.base_dir, 'TrainingData')
            biomedical_image_metric(predict_list, val_list, train_dir + '/')