예제 #1
0
    def __call__(self, model, scanner):
        """Run inference, optionally summing over test-time augmentations.

        Args:
            model: Forwarded unchanged to ``self.forward``.
            scanner: Scanner exposing ``dataset`` and ``outputs``.

        Returns:
            A 2-tuple ``(outputs, per_aug)``. When ``self.test_aug`` is
            falsy this is ``(self.forward(model, scanner), None)``.
            Otherwise ``outputs`` is ``scanner.outputs`` after accumulation
            and normalization, and ``per_aug`` maps every output key to the
            list of reverted per-augmentation outputs if ``self.variance``
            is truthy, else ``None``.
        """
        source = scanner.dataset

        # No augmentation configured: a single plain forward pass.
        if not self.test_aug:
            return (self.forward(model, scanner), None)

        # Per-augmentation outputs are only kept for variance computation.
        if self.variance:
            per_aug = {name: list()
                       for name, _ in scanner.outputs.data.items()}
        else:
            per_aug = None

        num_passes = 0.0
        for aug in self.test_aug:
            # Decimal -> zero-padded 4-digit binary -> array of 0/1 flags.
            bits = bin(aug)[2:].zfill(4)
            rule = np.array([int(bit) for bit in bits])
            print("Test-time augmentation {}".format(rule))

            # Build a flipped copy of every volume in the source dataset.
            flipped = Dataset(spec=self.in_spec)
            for name, volume in source.data.items():
                flipped.add_data(name, fwd_utils.flip(volume._data, rule=rule))

            # Forward scan over the augmented dataset.
            result = self.forward(model, self.make_forward_scanner(flipped))

            # Undo the flip and add the result into the scanner's outputs.
            for name, target in scanner.outputs.data.items():
                print("Accumulate to {}...".format(name))
                prediction = result.get_data(name)

                # Only the 'affinity' output is reverted with an explicit dst.
                if name == 'affinity':
                    dst = (1, 1, 1)
                else:
                    dst = None
                restored = fwd_utils.revert_flip(prediction, rule=rule, dst=dst)
                target._data += restored

                if self.variance:
                    per_aug[name].append(restored)

            num_passes += 1

        # Normalize: either divide the accumulated data in place, or record
        # the number of passes in the output's norm buffer.
        for name, target in scanner.outputs.data.items():
            print("Normalize {}...".format(name))
            if self.precomputed:
                target._data[...] /= num_passes
            else:
                target._norm._data[...] = num_passes

        return (scanner.outputs, per_aug)
예제 #2
0
파일: aff.py 프로젝트: torms3/DeepEM
    def build_dataset(self, tag, data):
        """Build a tagged Dataset with the image, affinity labels and mask.

        Args:
            tag: Passed through to ``Dataset(tag=...)``.
            data: Mapping providing the keys ``'img'``, ``'seg'`` and
                ``'loc'``; it is also handed to ``self.get_mask``.

        Returns:
            The populated ``Dataset``.
        """
        # Read every required entry before the Dataset is created, so a
        # missing key fails early.
        image, segmentation, location = data['img'], data['seg'], data['loc']
        mask = self.get_mask(data)

        dataset = Dataset(tag=tag)
        for name, volume in (('input', image), ('affinity', segmentation)):
            dataset.add_data(key=name, data=volume)
        dataset.add_mask(key='affinity_mask', data=mask, loc=location)

        return dataset
예제 #3
0
    def build_dataset(self, tag, data):
        """Build a tagged Dataset with the image, mitochondria labels and mask.

        Args:
            tag: Passed through to ``Dataset(tag=...)``.
            data: Mapping providing the keys ``'img'``, ``'mit'`` and
                ``'loc'``; it is also handed to ``self.get_mask``.

        Returns:
            The populated ``Dataset``.
        """
        image, mito = data['img'], data['mit']
        # 'loc' is not used below; the lookup is kept so that a missing key
        # still raises at this point.
        _ = data['loc']
        mask = self.get_mask(data)

        dataset = Dataset(tag=tag)
        for name, volume in (('input', image), ('mitochondria', mito)):
            dataset.add_data(key=name, data=volume)
        dataset.add_mask(key='mitochondria_mask', data=mask)

        return dataset
예제 #4
0
    def build_dataset(self, datadir, vol):
        """Read ``<vol>_img.h5`` / ``<vol>_lbl.h5`` and wrap them in a Dataset.

        Args:
            datadir: Directory containing the two HDF5 files.
            vol: File-name prefix shared by the image and label files.

        Returns:
            A ``Dataset`` with the scaled image under ``'input'`` and the
            binarized labels under ``'soma_label'``.
        """
        raw = read_h5(os.path.join(datadir, vol + "_img.h5"))
        labels = read_h5(os.path.join(datadir, vol + "_lbl.h5"))
        labels = labels.astype("float32")

        # Divide the image by 255 and cast it to float32.
        scaled = raw / 255.
        image = scaled.astype("float32")

        # Collapse every nonzero label to 1.
        nonzero = labels != 0
        labels[nonzero] = 1

        dataset = Dataset()
        dataset.add_data(key='input', data=image)
        dataset.add_data(key='soma_label', data=labels)
        return dataset
예제 #5
0
    def build_dataset(self, datadir, vol):
        """Read ``<vol>_img.h5`` / ``<vol>_syn.h5`` and wrap them in a Dataset.

        Args:
            datadir: Directory containing the two HDF5 files.
            vol: File-name prefix shared by the image and label files.

        Returns:
            A ``Dataset`` with the scaled image under ``'input'`` and the
            binarized labels under ``'cleft_label'``.
        """
        image_file = os.path.join(datadir, vol + "_img.h5")
        image = read_h5(image_file)
        label_file = os.path.join(datadir, vol + "_syn.h5")
        cleft = read_h5(label_file).astype("float32")

        # Divide the image by 255 and cast it; set every nonzero label to 1.
        image = (image / 255.).astype("float32")
        cleft[cleft != 0] = 1

        dataset = Dataset()
        for name, volume in (('input', image), ('cleft_label', cleft)):
            dataset.add_data(key=name, data=volume)
        return dataset
예제 #6
0
파일: psd.py 프로젝트: torms3/DeepEM
    def build_dataset(self, tag, data):
        """Build a tagged Dataset with the image, synapse labels and mask.

        Args:
            tag: Passed through to ``Dataset(tag=...)``.
            data: Mapping providing the keys ``'img'``, ``'psd'``,
                ``'psd_msk'`` and ``'loc'``; it is also handed to
                ``self.get_mask``.

        Returns:
            The populated ``Dataset``.
        """
        image, psd_labels, psd_mask, location = (
            data[field] for field in ('img', 'psd', 'psd_msk', 'loc')
        )
        # The generic mask is still computed, but its result is not used:
        # the synapse mask comes from data['psd_msk'] instead.
        self.get_mask(data)

        dataset = Dataset(tag=tag)
        dataset.add_data(key='input', data=image)
        dataset.add_data(key='synapse', data=psd_labels)
        dataset.add_mask(key='synapse_mask', data=psd_mask, loc=location)

        return dataset
예제 #7
0
파일: utils.py 프로젝트: seung-lab/DeepEM
def make_forward_scanner(opt, data_name=None):
    """Build a ForwardScanner over a Dataset holding one 'input' image.

    The image comes from one of two places:

    * ``opt.gs_input`` truthy: cut out through ``cv_utils.cutout``, with
      optional per-slice normalization (``opt.gs_input_norm``) and optional
      masking (``opt.gs_input_mask``).
    * otherwise: random data when ``opt.dummy`` is truthy, else the file
      ``opt.data_dir/data_name/opt.input_name`` read with ``emio.imread``;
      optionally reflect-padded according to ``opt.mirror``.

    Args:
        opt: Options object carrying the attributes read below.
        data_name: Name of the data sub-directory; required (asserted) when
            ``opt.gs_input`` is falsy.

    Returns:
        ``ForwardScanner(dataset, opt.scan_spec, **opt.scan_params)``.

    Raises:
        ImportError: If ``deepem.test.cv_utils`` cannot be imported on the
            ``opt.gs_input`` path.
        AssertionError: If ``opt.gs_input_norm`` does not have two entries,
            or if ``data_name`` is None on the local path.
    """
    # The original wrapped this branch in try/except blocks that only
    # re-raised (one of them a bare ``except:``). Errors propagate exactly
    # the same without them, so the wrappers are dropped.
    if opt.gs_input:
        # Function-scope import: only this branch uses cv_utils.
        from deepem.test import cv_utils
        img = cv_utils.cutout(opt, opt.gs_input, dtype='uint8')

        # Optional input histogram normalization
        if opt.gs_input_norm:
            assert len(opt.gs_input_norm) == 2
            low, high = opt.gs_input_norm
            img = normalize_per_slice(img, lowerfract=low, upperfract=high)

        # [0, 255] -> [0.0, 1.0]
        img = (img/255.).astype('float32')

        # Optional input mask: zero the image wherever the mask is positive.
        if opt.gs_input_mask:
            msk = cv_utils.cutout(opt, opt.gs_input_mask, dtype='uint8')
            img[msk > 0] = 0
    else:
        assert data_name is not None
        print(data_name)
        # Read an EM image.
        if opt.dummy:
            # Random input shaped by the last three entries of dummy_inputsz.
            img = np.random.rand(*opt.dummy_inputsz[-3:]).astype('float32')
        else:
            fpath = os.path.join(opt.data_dir, data_name, opt.input_name)
            img = emio.imread(fpath)
            img = (img/255.).astype('float32')

        # Border mirroring: pad each axis by half of its mirror size on
        # both sides.
        if opt.mirror:
            pad_width = [(x//2, x//2) for x in opt.mirror]
            img = np.pad(img, pad_width, 'reflect')

    # ForwardScanner
    dataset = Dataset(spec=opt.in_spec)
    dataset.add_data('input', img)
    return ForwardScanner(dataset, opt.scan_spec, **opt.scan_params)
예제 #8
0
    def build_dataset(self, datadir, vol):
        """Load the image/label pair for ``vol`` and wrap it in a Dataset.

        The image is looked up as ``<vol>_img.h5`` first and as
        ``<vol>_img.tif`` second. The label file ``<vol>_lbl.*`` is read
        with the same extension as the image that was found, so the raw
        image and the label image must be stored in the same format.

        Args:
            datadir: Directory containing the volume files.
            vol: File-name prefix shared by the image and label files.

        Returns:
            A ``Dataset`` with the scaled image under ``'input'`` and the
            binarized labels under ``'soma_label'``.

        Raises:
            FileNotFoundError: If neither image file exists in ``datadir``.
                The original fell through to an ``UnboundLocalError`` here.
        """
        # Pick the first supported extension whose image file exists.
        for ext in (".h5", ".tif"):
            img_path = os.path.join(datadir, vol + "_img" + ext)
            if os.path.isfile(img_path):
                break
        else:
            raise FileNotFoundError(
                "No image found for volume {!r} in {!r} "
                "(tried _img.h5 and _img.tif)".format(vol, datadir))

        img = read_img(img_path)
        lbl_path = os.path.join(datadir, vol + "_lbl" + ext)
        soma = read_img(lbl_path).astype("float32")

        # Preprocessing: divide the image by 255 and cast it to float32;
        # set every nonzero label to 1.
        img = (img / 255.).astype("float32")
        soma[soma != 0] = 1

        # Create Dataset.
        dset = Dataset()
        dset.add_data(key='input', data=img)
        dset.add_data(key='soma_label', data=soma)
        return dset
예제 #9
0
    def build_dataset(self, tag, data):
        """Build a tagged multi-target Dataset.

        The Dataset gets the image plus affinity, synapse, myelin and
        blood-vessel labels. Every target uses the mask returned by
        ``self.get_mask``; only the affinity mask is added with ``loc``.

        Args:
            tag: Passed through to ``Dataset(tag=...)``.
            data: Mapping providing the keys ``'img'``, ``'seg'``,
                ``'syn'``, ``'mye'``, ``'blv'`` and ``'loc'``; it is also
                handed to ``self.get_mask``.

        Returns:
            The populated ``Dataset``.
        """
        image, segmentation, synapse, myelin, vessel, location = (
            data[field]
            for field in ('img', 'seg', 'syn', 'mye', 'blv', 'loc')
        )
        mask = self.get_mask(data)

        dataset = Dataset(tag=tag)
        dataset.add_data(key='input', data=image)

        # Affinity is the only target whose mask is registered with loc.
        dataset.add_data(key='affinity', data=segmentation)
        dataset.add_mask(key='affinity_mask', data=mask, loc=location)

        # The remaining targets all share the same mask.
        shared_mask_targets = (
            ('synapse', synapse),
            ('myelin', myelin),
            ('blood_vessel', vessel),
        )
        for name, volume in shared_mask_targets:
            dataset.add_data(key=name, data=volume)
            dataset.add_mask(key=name + '_mask', data=mask)

        return dataset
예제 #10
0
    def build_dataset(self, tag, data):
        """Build a tagged Dataset with affinity, synapse and myelin targets.

        The affinity and myelin targets use the mask returned by
        ``self.get_mask``; the synapse target uses ``data['psd_msk']``.
        Only the affinity mask is added with ``loc``.

        Args:
            tag: Passed through to ``Dataset(tag=...)``.
            data: Mapping providing the keys ``'img'``, ``'seg'``,
                ``'psd'``, ``'psd_msk'``, ``'mye'`` and ``'loc'``; it is
                also handed to ``self.get_mask``.

        Returns:
            The populated ``Dataset``.
        """
        image, segmentation, psd_labels, psd_mask, myelin, location = (
            data[field]
            for field in ('img', 'seg', 'psd', 'psd_msk', 'mye', 'loc')
        )
        mask = self.get_mask(data)

        dataset = Dataset(tag=tag)
        dataset.add_data(key='input', data=image)

        dataset.add_data(key='affinity', data=segmentation)
        dataset.add_mask(key='affinity_mask', data=mask, loc=location)

        # (target name, label volume, mask volume), added without loc.
        for name, volume, target_mask in (
            ('synapse', psd_labels, psd_mask),
            ('myelin', myelin, mask),
        ):
            dataset.add_data(key=name, data=volume)
            dataset.add_mask(key=name + '_mask', data=target_mask)

        return dataset
예제 #11
0
파일: aff_glia.py 프로젝트: torms3/DeepEM
    def build_dataset(self, tag, data):
        """Build a tagged Dataset with affinity and glia targets.

        Args:
            tag: Passed through to ``Dataset(tag=...)``.
            data: Mapping providing the keys ``'img'``, ``'seg'``,
                ``'loc'`` and ``'glia'``, and optionally ``'gmsk'``; it is
                also handed to ``self.get_mask``.

        Returns:
            The populated ``Dataset``.
        """
        image, segmentation, location = data['img'], data['seg'], data['loc']
        mask = self.get_mask(data)
        glia_labels = data['glia']

        # Use the dedicated glia mask when one is supplied; otherwise fall
        # back to the mask shared with the affinity target.
        if 'gmsk' in data:
            glia_mask = data['gmsk']
        else:
            glia_mask = mask

        dataset = Dataset(tag=tag)
        dataset.add_data(key='input', data=image)
        dataset.add_data(key='affinity', data=segmentation)
        dataset.add_mask(key='affinity_mask', data=mask, loc=location)
        dataset.add_data(key='glia', data=glia_labels)
        dataset.add_mask(key='glia_mask', data=glia_mask)

        return dataset