Example #1
    def compute_loss(self,z,y_, **kwargs):
        islices = kwargs.pop('islices',None)
        if islices is not None:
            mask = [self.immask[islices].ravel() for i in range(len(self.labelset))]
        else:
            mask = self.mask

        if np.sum(y_<-1e-6) > 0:
            miny = np.min(y_)
            logger.warning('negative (<-1e-6) values in y_. min = {:.3}'.format(miny))

        #self.use_ideal_loss = True

        #if self.use_ideal_loss:
        if self.loss_type in ['ideal', 'none']:
            loss = loss_functions.ideal_loss(z,y_,mask=mask)
        elif self.loss_type=='squareddiff':
            loss = loss_functions.anchor_loss(z,y_,mask=mask)
        elif self.loss_type=='laplacian':
            loss = loss_functions.laplacian_loss(z,y_,mask=mask)
        elif self.loss_type=='linear':
            loss = loss_functions.linear_loss(z,y_,mask=mask)
        else:
            raise ValueError('unknown loss type: {}'.format(self.loss_type))
        return loss*self.loss_factor
Example #2
    def compute_loss(self, z, y_, **kwargs):
        islices = kwargs.pop('islices', None)
        if islices is not None:
            mask = [
                self.immask[islices].ravel() for i in range(len(self.labelset))
            ]
        else:
            mask = self.mask

        if np.sum(y_ < -1e-6) > 0:
            miny = np.min(y_)
            logger.warning(
                'negative (<-1e-6) values in y_. min = {:.3}'.format(miny))

        #self.use_ideal_loss = True

        #if self.use_ideal_loss:
        if self.loss_type in ['ideal', 'none']:
            loss = loss_functions.ideal_loss(z, y_, mask=mask)
        elif self.loss_type == 'squareddiff':
            loss = loss_functions.anchor_loss(z, y_, mask=mask)
        elif self.loss_type == 'laplacian':
            loss = loss_functions.laplacian_loss(z, y_, mask=mask)
        elif self.loss_type == 'linear':
            loss = loss_functions.linear_loss(z, y_, mask=mask)
        else:
            raise ValueError('unknown loss type: {}'.format(self.loss_type))
        return loss * self.loss_factor
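The dispatch above assumes a loss_functions module providing ideal_loss, anchor_loss, laplacian_loss and linear_loss. As a rough illustration of the first of these, here is a minimal NumPy sketch, assuming ideal_loss means 1 - Dice between the hard (argmax) labeling of y and the binary ground truth z, restricted to the masked pixels; the actual module may compute it differently.

import numpy as np

def ideal_loss_sketch(z, y, mask):
    # z, y, mask: arrays of shape (nlabel, npixel); z is the binary ground
    # truth, y a (soft) segmentation, mask selects the pixels that count.
    z = np.asarray(z, dtype=float)
    y = np.asarray(y, dtype=float)
    mask = np.asarray(mask, dtype=bool)

    # hard labeling of y: one-hot encoding of the argmax label per pixel
    hard = (np.argmax(y, axis=0) == np.arange(z.shape[0])[:, None]).astype(float)

    # pooled Dice overlap restricted to masked pixels
    inter = np.sum(hard * z * mask)
    total = np.sum((hard + z) * mask)
    dice = 2.0 * inter / total if total > 0 else 1.0
    return 1.0 - dice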
def compute_losses(z,y,mask):
    ## loss 0 : 1 - Dice(y,z)
    loss0 = loss_functions.ideal_loss(z,y,mask=mask)
    logger.info('Tloss = {}'.format(loss0))
    
    ## loss2: squared difference with ztilde
    loss1 = loss_functions.anchor_loss(z,y,mask=mask)
    logger.info('SDloss = {}'.format(loss1))
    
    ## loss3: laplacian loss
    loss2 = loss_functions.laplacian_loss(z,y,mask=mask)
    logger.info('LAPloss = {}'.format(loss2))

    ## loss4: linear loss
    loss3 = loss_functions.linear_loss(z,y,mask=mask)
    logger.info('LINloss = {}'.format(loss3))
    
    return loss0, loss1, loss2, loss3
Example #4
def compute_losses(z, y, mask):
    ## loss 0 : 1 - Dice(y,z)
    loss0 = loss_functions.ideal_loss(z, y, mask=mask)
    logger.info('Tloss = {}'.format(loss0))

    ## loss2: squared difference with ztilde
    loss1 = loss_functions.anchor_loss(z, y, mask=mask)
    logger.info('SDloss = {}'.format(loss1))

    ## loss3: laplacian loss
    loss2 = loss_functions.laplacian_loss(z, y, mask=mask)
    logger.info('LAPloss = {}'.format(loss2))

    ## loss4: linear loss
    loss3 = loss_functions.linear_loss(z, y, mask=mask)
    logger.info('LINloss = {}'.format(loss3))

    return loss0, loss1, loss2, loss3
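For reference, compute_losses expects z, y and mask to share the (nlabel, npixel) layout used throughout these examples. A minimal, hypothetical call with synthetic data might look like the following; it assumes compute_losses, loss_functions and logger are available as in the snippets above.

import numpy as np

nlabel, npixel = 4, 1000
labelset = np.arange(nlabel)

# synthetic ground-truth labels and their one-hot encoding z
labels = np.random.randint(0, nlabel, size=npixel)
z = (labels == np.c_[labelset])             # shape (nlabel, npixel), boolean

# a soft segmentation y whose columns sum to one
y = np.random.rand(nlabel, npixel)
y /= y.sum(axis=0, keepdims=True)

# mask out nothing in this toy example
mask = np.ones((nlabel, npixel), dtype=bool)

loss0, loss1, loss2, loss3 = compute_losses(z, y, mask)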
Example #5
    def run_svm_inference(self,test,w, test_dir):
        logger.info('running inference on: {}'.format(test))
        
        ## normalize w
        # w = w / np.sqrt(np.dot(w,w))
        strw = ' '.join('{:.3}'.format(val) for val in np.asarray(w)*self.psi_scale)
        logger.debug('scaled w=[{}]'.format(strw))
    
        weights_laplacians = np.asarray(w)[self.indices_laplacians]
        weights_laplacians_h = np.asarray(self.hand_tuned_w)[self.indices_laplacians]
        weights_priors = np.asarray(w)[self.indices_priors]
        weights_priors_h = np.asarray(self.hand_tuned_w)[self.indices_priors]
    
        ## segment test image with trained w
        '''
        def meta_weight_functions(im,i,j,_w):    
            data = 0
            for iwf,wf in enumerate(self.laplacian_functions):
                _data = wf(im,i,j)
                data += _w[iwf]*_data
            return data
        weight_function = lambda im: meta_weight_functions(im,i,j,weights_laplacians)
        weight_function_h = lambda im: meta_weight_functions(im,i,j,weights_laplacians_h)
        '''
        weight_function = MetaLaplacianFunction(
            weights_laplacians,
            self.laplacian_functions)
        
        weight_function_h = MetaLaplacianFunction(
            weights_laplacians_h,
            self.laplacian_functions)
        
        ## load images and ground truth
        file_seg = self.dir_reg + test + 'seg.hdr'
        file_im  = self.dir_reg + test + 'gray.hdr'
        im  = io_analyze.load(file_im)
        seg = io_analyze.load(file_seg)
        seg.flat[~np.in1d(seg.ravel(),self.labelset)] = self.labelset[0]
        
        nim = im/np.std(im) # normalize image by std

        ## test training data ?
        inference_train = True
        if inference_train:
            train_ims, train_segs, train_metas = self.training_set
            for tim, tz, tmeta in zip(train_ims, train_segs, train_metas):
                ## retrieve metadata
                islices = tmeta.pop('islices',None)
                imask = tmeta.pop('imask', None)
                iimask = tmeta.pop('iimask',None)
                if islices is not None:
                    tseeds = self.seeds[islices]
                    tprior = {
                        'data': np.asarray(self.prior['data'])[:,iimask],
                        'imask': imask,
                        'variance': np.asarray(self.prior['variance'])[:,iimask],
                        'labelset': self.labelset,
                        }
                    if 'intensity' in self.prior: 
                        tprior['intensity'] = self.prior['intensity']
                else:
                    tseeds = self.seeds
                    tprior = self.prior

                ## prior
                tseg = self.labelset[np.argmax(tz, axis=0)].reshape(tim.shape)
                tanchor_api = MetaAnchor(
                    tprior,
                    self.prior_functions,
                    weights_priors,
                    image=tim,
                    )
                tsol,ty = rwsegment.segment(
                    tim, 
                    tanchor_api, 
                    seeds=tseeds,
                    weight_function=weight_function,
                    **self.rwparams_inf
                    )
                ## compute Dice coefficient
                tdice = compute_dice_coef(tsol, tseg, labelset=self.labelset)
                logger.info('Dice coefficients for train: \n{}'.format(tdice))
                nlabel = len(self.labelset)
                tflatmask = np.zeros(ty.shape, dtype=bool)
                tflatmask[:,imask] = True
                loss0 = loss_functions.ideal_loss(tz,ty,mask=tflatmask)
                logger.info('Tloss = {}'.format(loss0))
                ## loss2: squared difference with ztilde
                loss1 = loss_functions.anchor_loss(tz,ty,mask=tflatmask)
                logger.info('SDloss = {}'.format(loss1))
                ## loss3: laplacian loss
                loss2 = loss_functions.laplacian_loss(tz,ty,mask=tflatmask)
                logger.info('LAPloss = {}'.format(loss2))


                tanchor_api_h = MetaAnchor(
                    tprior,
                    self.prior_functions,
                    weights_priors_h,
                    image=tim,
                    )
            
                tsol,ty = rwsegment.segment(
                    tim, 
                    tanchor_api_h, 
                    seeds=tseeds,
                    weight_function=weight_function_h,
                    **self.rwparams_inf
                    )
                ## compute Dice coefficient
                tdice = compute_dice_coef(tsol, tseg, labelset=self.labelset)
                logger.info('Dice coefficients for train (hand-tuned): \n{}'.format(tdice))
                loss0 = loss_functions.ideal_loss(tz,ty,mask=tflatmask)
                logger.info('Tloss (hand-tuned) = {}'.format(loss0))
                ## loss2: squared difference with ztilde
                loss1 = loss_functions.anchor_loss(tz,ty,mask=tflatmask)
                logger.info('SDloss (hand-tuned) = {}'.format(loss1))
                ## loss3: laplacian loss
                loss2 = loss_functions.laplacian_loss(tz,ty,mask=tflatmask)
                logger.info('LAPloss (hand-tuned) = {}'.format(loss2))
                break
 
        ## prior
        anchor_api = MetaAnchor(
            self.prior,
            self.prior_functions,
            weights_priors,
            image=nim,
            )
    
        sol,y = rwsegment.segment(
            nim, 
            anchor_api, 
            seeds=self.seeds,
            weight_function=weight_function,
            **self.rwparams_inf
            )
        
        ## compute Dice coefficient
        dice = compute_dice_coef(sol, seg,labelset=self.labelset)
        logger.info('Dice coefficients: \n{}'.format(dice))

        ## objective
        en_rw = rwsegment.energy_rw(
            nim, y, seeds=self.seeds,weight_function=weight_function, **self.rwparams_inf)
        en_anchor = rwsegment.energy_anchor(
            nim, y, anchor_api, seeds=self.seeds, **self.rwparams_inf)
        obj = en_rw + en_anchor
        logger.info('Objective = {:.3}'.format(obj))

        
        ## compute losses
        z = seg.ravel()==np.c_[self.labelset]
        mask = self.seeds < 0
        flatmask = mask.ravel()*np.ones((len(self.labelset),1))
        
        ## loss 0 : 1 - Dice(y,z)
        loss0 = loss_functions.ideal_loss(z,y,mask=flatmask)
        logger.info('Tloss = {}'.format(loss0))
        
        ## loss2: squared difference with ztilde
        loss1 = loss_functions.anchor_loss(z,y,mask=flatmask)
        logger.info('SDloss = {}'.format(loss1))
        
        ## loss3: laplacian loss
        loss2 = loss_functions.laplacian_loss(z,y,mask=flatmask)
        logger.info('LAPloss = {}'.format(loss2))

        ## loss4: linear loss
        loss3 = loss_functions.linear_loss(z,y,mask=flatmask)
        logger.info('LINloss = {}'.format(loss3))
       
        ## saving
        if self.debug:
            pass
        elif self.isroot:
            outdir = self.dir_inf + test_dir
            logger.info('saving data in: {}'.format(outdir))
            if not os.path.isdir(outdir):
                os.makedirs(outdir)
                
            #io_analyze.save(outdir + 'im.hdr',im.astype(np.int32))
            #np.save(outdir + 'y.npy',y)        
            #io_analyze.save(outdir + 'sol.hdr',sol.astype(np.int32))
            np.savetxt(outdir + 'objective.txt', [obj])
            np.savetxt(
                outdir + 'dice.txt', 
                np.c_[dice.keys(),dice.values()],fmt='%d %f')
        
            f = open(outdir + 'losses.txt', 'w')
            f.write('ideal_loss\t{}\n'.format(loss0))
            f.write('anchor_loss\t{}\n'.format(loss1))
            f.write('laplacian_loss\t{}\n'.format(loss2))
            f.close()
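The commented-out meta_weight_functions above hints at what MetaLaplacianFunction does: it combines several pairwise weight functions into a single one using the learned weights. A small self-contained sketch of that idea follows; the class name and call signature are illustrative, not the library's actual API.

import numpy as np

class MetaWeightSketch(object):
    # combine several edge-weight functions with scalar weights
    def __init__(self, weights, functions):
        self.weights = np.asarray(weights, dtype=float)
        self.functions = list(functions)

    def __call__(self, im, i, j):
        # weighted sum of the individual weight functions on pixel pairs (i, j)
        data = 0.0
        for w, f in zip(self.weights, self.functions):
            data = data + w * f(im, i, j)
        return data

# toy example with two weight functions
f1 = lambda im, i, j: np.exp(-(im.flat[i] - im.flat[j]) ** 2)
f2 = lambda im, i, j: np.ones(len(i))
wf = MetaWeightSketch([0.8, 0.2], [f1, f2])

im = np.random.rand(10, 10)
i, j = np.array([0, 1, 2]), np.array([1, 2, 3])
print(wf(im, i, j))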
Example #6
    def compute_approximate_aci(self, w,x,z,y0,**kwargs):
        logger.info("using approximate aci (Danny's)")
        islices = kwargs.pop('islices',None)
        imask = kwargs.pop('imask',None)
        iimask = kwargs.pop('iimask',None)
        if islices is not None:
            seeds = self.seeds[islices]
            mask = [self.immask[islices].ravel() for i in range(len(self.labelset))]
            prior = {
                'data': np.asarray(self.prior['data'])[:,iimask],
                'imask': imask,
                'variance': np.asarray(self.prior['variance'])[:,iimask],
                'labelset': self.labelset,
                }
            if 'intensity' in self.prior: prior['intensity'] = self.prior['intensity']
        else:
            mask = self.mask
            seeds = self.seeds
            prior = self.prior
        
        weight_function = MetaLaplacianFunction(
            np.asarray(w)[self.indices_laplacians],
            self.laplacian_functions,
            )
        
        ## combine all prior models
        anchor_api = MetaAnchor(
            prior=prior,
            prior_models=self.prior_models,
            prior_weights=np.asarray(w)[self.indices_priors],
            image=x,
            )

        class GroundTruthAnchor(object):
            def __init__(self, anchor_api, gt, gt_weights):
                self.anchor_api = anchor_api
                self.gt = gt
                self.gt_weights = gt_weights
            def get_labelset(self): 
                return self.anchor_api.get_labelset()

            def get_anchor_and_weights(self, D, indices):
                anchor, weights = self.anchor_api.get_anchor_and_weights(D,indices)
                gt_weights = self.gt_weights[:,indices]
                gt = self.gt[:,indices]
                new_weights = weights + gt_weights
                new_anchor = (anchor * weights + gt*gt_weights) / new_weights
                return new_anchor, new_weights
                
        self.approx_aci_maxiter = 200
        self.approx_aci_maxstep = 1e-2
        z_weights = np.zeros(np.asarray(z).shape)
        z_label = np.argmax(z,axis=0)
        for i in range(self.approx_aci_maxiter):
            logger.debug("approx aci, iter={}".format(i))
    
            ## add ground truth to anchor api
            modified_api = GroundTruthAnchor(anchor_api, z, z_weights)

            ## inference
            y_ = rwsegment.segment(
                x, 
                modified_api,
                seeds=seeds,
                weight_function=weight_function,
                return_arguments=['y'],
                **self.rwparams
                )

            ## loss            
            #loss = self.compute_loss(z,y_, islices=islices)
            loss = loss_functions.ideal_loss(z,y_,mask=mask)
            logger.debug('loss = {}'.format(loss))
            if loss < 1e-8: 
                break
            
            ## update weights
            delta = np.max(y_ - y_[z_label, np.arange(y_.shape[1])], axis=0)
            delta = np.clip(delta, 0, self.approx_aci_maxstep)
            z_weights += delta

        return y_        
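GroundTruthAnchor above blends the prior anchor with the ground truth z, weighting the latter by the per-pixel z_weights accumulated during the loop. The blend is a weighted average; a tiny standalone illustration with made-up values (shapes follow the (nlabel, npixel) convention used above):

import numpy as np

anchor  = np.array([[0.7, 0.2], [0.3, 0.8]])   # prior anchor per label/pixel
weights = np.array([[1.0, 1.0], [1.0, 1.0]])   # prior confidence
gt      = np.array([[1.0, 0.0], [0.0, 1.0]])   # ground truth
gt_w    = np.array([[0.5, 0.0], [0.5, 0.0]])   # ground-truth weights so far

new_weights = weights + gt_w
new_anchor = (anchor * weights + gt * gt_w) / new_weights
# pixels with nonzero gt_w are pulled towards the ground truth
print(new_anchor)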
    def run_svm_inference(self, test, w, test_dir):
        logger.info('running inference on: {}'.format(test))

        ## normalize w
        # w = w / np.sqrt(np.dot(w,w))
        strw = ' '.join('{:.3}'.format(val)
                        for val in np.asarray(w) * self.psi_scale)
        logger.debug('scaled w=[{}]'.format(strw))

        weights_laplacians = np.asarray(w)[self.indices_laplacians]
        weights_laplacians_h = np.asarray(
            self.hand_tuned_w)[self.indices_laplacians]
        weights_priors = np.asarray(w)[self.indices_priors]
        weights_priors_h = np.asarray(self.hand_tuned_w)[self.indices_priors]

        ## segment test image with trained w
        '''
        def meta_weight_functions(im,i,j,_w):    
            data = 0
            for iwf,wf in enumerate(self.laplacian_functions):
                _data = wf(im,i,j)
                data += _w[iwf]*_data
            return data
        weight_function = lambda im: meta_weight_functions(im,i,j,weights_laplacians)
        weight_function_h = lambda im: meta_weight_functions(im,i,j,weights_laplacians_h)
        '''
        weight_function = MetaLaplacianFunction(weights_laplacians,
                                                self.laplacian_functions)

        weight_function_h = MetaLaplacianFunction(weights_laplacians_h,
                                                  self.laplacian_functions)

        ## load images and ground truth
        file_seg = self.dir_reg + test + 'seg.hdr'
        file_im = self.dir_reg + test + 'gray.hdr'
        im = io_analyze.load(file_im)
        seg = io_analyze.load(file_seg)
        seg.flat[~np.in1d(seg.ravel(), self.labelset)] = self.labelset[0]

        nim = im / np.std(im)  # normalize image by std

        ## test training data ?
        inference_train = True
        if inference_train:
            train_ims, train_segs, train_metas = self.training_set
            for tim, tz, tmeta in zip(train_ims, train_segs, train_metas):
                ## retrieve metadata
                islices = tmeta.pop('islices', None)
                imask = tmeta.pop('imask', None)
                iimask = tmeta.pop('iimask', None)
                if islices is not None:
                    tseeds = self.seeds[islices]
                    tprior = {
                        'data': np.asarray(self.prior['data'])[:, iimask],
                        'imask': imask,
                        'variance': np.asarray(self.prior['variance'])[:,
                                                                       iimask],
                        'labelset': self.labelset,
                    }
                    if 'intensity' in self.prior:
                        tprior['intensity'] = self.prior['intensity']
                else:
                    tseeds = self.seeds
                    tprior = self.prior

                ## prior
                tseg = self.labelset[np.argmax(tz, axis=0)].reshape(tim.shape)
                tanchor_api = MetaAnchor(
                    tprior,
                    self.prior_functions,
                    weights_priors,
                    image=tim,
                )
                tsol, ty = rwsegment.segment(tim,
                                             tanchor_api,
                                             seeds=tseeds,
                                             weight_function=weight_function,
                                             **self.rwparams_inf)
                ## compute Dice coefficient
                tdice = compute_dice_coef(tsol, tseg, labelset=self.labelset)
                logger.info('Dice coefficients for train: \n{}'.format(tdice))
                nlabel = len(self.labelset)
                tflatmask = np.zeros(ty.shape, dtype=bool)
                tflatmask[:, imask] = True
                loss0 = loss_functions.ideal_loss(tz, ty, mask=tflatmask)
                logger.info('Tloss = {}'.format(loss0))
                ## loss2: squared difference with ztilde
                loss1 = loss_functions.anchor_loss(tz, ty, mask=tflatmask)
                logger.info('SDloss = {}'.format(loss1))
                ## loss3: laplacian loss
                loss2 = loss_functions.laplacian_loss(tz, ty, mask=tflatmask)
                logger.info('LAPloss = {}'.format(loss2))

                tanchor_api_h = MetaAnchor(
                    tprior,
                    self.prior_functions,
                    weights_priors_h,
                    image=tim,
                )

                tsol, ty = rwsegment.segment(tim,
                                             tanchor_api_h,
                                             seeds=tseeds,
                                             weight_function=weight_function_h,
                                             **self.rwparams_inf)
                ## compute Dice coefficient
                tdice = compute_dice_coef(tsol, tseg, labelset=self.labelset)
                logger.info(
                    'Dice coefficients for train (hand-tuned): \n{}'.format(
                        tdice))
                loss0 = loss_functions.ideal_loss(tz, ty, mask=tflatmask)
                logger.info('Tloss (hand-tuned) = {}'.format(loss0))
                ## loss2: squared difference with ztilde
                loss1 = loss_functions.anchor_loss(tz, ty, mask=tflatmask)
                logger.info('SDloss (hand-tuned) = {}'.format(loss1))
                ## loss3: laplacian loss
                loss2 = loss_functions.laplacian_loss(tz, ty, mask=tflatmask)
                logger.info('LAPloss (hand-tuned) = {}'.format(loss2))
                break

        ## prior
        anchor_api = MetaAnchor(
            self.prior,
            self.prior_functions,
            weights_priors,
            image=nim,
        )

        sol, y = rwsegment.segment(nim,
                                   anchor_api,
                                   seeds=self.seeds,
                                   weight_function=weight_function,
                                   **self.rwparams_inf)

        ## compute Dice coefficient
        dice = compute_dice_coef(sol, seg, labelset=self.labelset)
        logger.info('Dice coefficients: \n{}'.format(dice))

        ## objective
        en_rw = rwsegment.energy_rw(nim,
                                    y,
                                    seeds=self.seeds,
                                    weight_function=weight_function,
                                    **self.rwparams_inf)
        en_anchor = rwsegment.energy_anchor(nim,
                                            y,
                                            anchor_api,
                                            seeds=self.seeds,
                                            **self.rwparams_inf)
        obj = en_rw + en_anchor
        logger.info('Objective = {:.3}'.format(obj))

        ## compute losses
        z = seg.ravel() == np.c_[self.labelset]
        mask = self.seeds < 0
        flatmask = mask.ravel() * np.ones((len(self.labelset), 1))

        ## loss 0 : 1 - Dice(y,z)
        loss0 = loss_functions.ideal_loss(z, y, mask=flatmask)
        logger.info('Tloss = {}'.format(loss0))

        ## loss2: squared difference with ztilde
        loss1 = loss_functions.anchor_loss(z, y, mask=flatmask)
        logger.info('SDloss = {}'.format(loss1))

        ## loss3: laplacian loss
        loss2 = loss_functions.laplacian_loss(z, y, mask=flatmask)
        logger.info('LAPloss = {}'.format(loss2))

        ## loss4: linear loss
        loss3 = loss_functions.linear_loss(z, y, mask=flatmask)
        logger.info('LINloss = {}'.format(loss3))

        ## saving
        if self.debug:
            pass
        elif self.isroot:
            outdir = self.dir_inf + test_dir
            logger.info('saving data in: {}'.format(outdir))
            if not os.path.isdir(outdir):
                os.makedirs(outdir)

            #io_analyze.save(outdir + 'im.hdr',im.astype(np.int32))
            #np.save(outdir + 'y.npy',y)
            #io_analyze.save(outdir + 'sol.hdr',sol.astype(np.int32))
            np.savetxt(outdir + 'objective.txt', [obj])
            np.savetxt(outdir + 'dice.txt',
                       np.c_[dice.keys(), dice.values()],
                       fmt='%d %f')

            f = open(outdir + 'losses.txt', 'w')
            f.write('ideal_loss\t{}\n'.format(loss0))
            f.write('anchor_loss\t{}\n'.format(loss1))
            f.write('laplacian_loss\t{}\n'.format(loss2))
            f.close()
    def process_sample(self,test,fold=None):

        ## get prior
        prior, mask = load_or_compute_prior_and_mask(
            test,
            fold=fold,
            force_recompute=self.force_recompute_prior)
        seeds   = (-1)*mask
        
        ## load image
        file_name = config.dir_reg + test + 'gray.hdr'        
        logger.info('segmenting data: {}'.format(file_name))
        im      = io_analyze.load(file_name)
        file_gt = config.dir_reg + test + 'seg.hdr'
        seg     = io_analyze.load(file_gt)
        seg.flat[~np.in1d(seg, self.labelset)] = self.labelset[0]
        
           
        ## normalize image
        nim = im/np.std(im)
            
        ## init anchor_api
        anchor_api = MetaAnchor(
            prior=prior,
            prior_models=self.prior_models,
            prior_weights=self.prior_weights,
            image=nim,
            )
           
        ## start segmenting
        #import ipdb; ipdb.set_trace()
        sol,y = rwsegment.segment(
            nim, 
            anchor_api,
            seeds=seeds, 
            labelset=self.labelset, 
            weight_function=self.weight_function,
            **self.params
            )

        ## compute losses
        z = seg.ravel()==np.c_[self.labelset]
        flatmask = mask.ravel()*np.ones((len(self.labelset),1))
        
        ## loss 0 : 1 - Dice(y,z)
        loss0 = loss_functions.ideal_loss(z,y,mask=flatmask)
        logger.info('Tloss = {}'.format(loss0))
        
        ## loss2: squared difference with ztilde
        loss1 = loss_functions.anchor_loss(z,y,mask=flatmask)
        logger.info('SDloss = {}'.format(loss1))
        
        ## loss3: laplacian loss
        loss2 = loss_functions.laplacian_loss(z,y,mask=flatmask)
        logger.info('LAPloss = {}'.format(loss2))
 
        ## loss4: linear loss
        loss3 = loss_functions.linear_loss(z,y,mask=flatmask)
        logger.info('LINloss = {}'.format(loss3))
        
        ## compute Dice coefficient per label
        dice    = compute_dice_coef(sol, seg,labelset=self.labelset)
        logger.info('Dice: {}'.format(dice))
        
        if not config.debug:
            if fold is not None:
                test_name = 'f{}_{}'.format(fold[0][:2], test)
            else:
                test_name = test
            outdir = config.dir_seg + \
                '/{}/{}'.format(self.model_name,test_name)
            logger.info('saving data in: {}'.format(outdir))
            if not os.path.isdir(outdir):
                os.makedirs(outdir)
        
            f = open(outdir + 'losses.txt', 'w')
            f.write('ideal_loss\t{}\n'.format(loss0))
            f.write('anchor_loss\t{}\n'.format(loss1))
            f.write('laplacian_loss\t{}\n'.format(loss2))
            f.close()
            
            io_analyze.save(outdir + 'sol.hdr', sol.astype(np.int32)) 
            np.savetxt(
                outdir + 'dice.txt', np.c_[dice.keys(),dice.values()],fmt='%d %.8f')
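The line z = seg.ravel() == np.c_[self.labelset] relies on NumPy broadcasting to one-hot encode the segmentation against the labelset, and flatmask repeats the pixel mask once per label. A short sketch of the same trick with toy data, assuming the (nlabel, npixel) convention used above:

import numpy as np

labelset = np.array([0, 13, 14])
seg = np.array([[13, 0], [14, 14]])             # toy segmentation, shape (2, 2)
mask = np.array([[True, False], [True, True]])  # unseeded pixels

# one-hot encoding: row k is True where seg equals labelset[k]
z = seg.ravel() == np.c_[labelset]              # shape (3, 4)

# repeat the flattened mask once per label
flatmask = mask.ravel() * np.ones((len(labelset), 1))  # shape (3, 4)

print(z.astype(int))
print(flatmask)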
    def compute_mean_segmentation(self, tests):
        for test in tests:
            file_gt = config.dir_reg + test + 'seg.hdr'
            seg     = io_analyze.load(file_gt)
            seg.flat[~np.in1d(seg, self.labelset)] = self.labelset[0]
           

            ## get prior
            prior, mask = load_or_compute_prior_and_mask(
                test,force_recompute=self.force_recompute_prior)
            mask = mask.astype(bool)            
           

            y = np.zeros((len(self.labelset),seg.size))
            y[:,0] = 1
            y.flat[prior['imask']] = prior['data']
 
            sol = np.zeros(seg.shape,dtype=np.int32)
            sol[mask] = self.labelset[np.argmax(prior['data'],axis=0)]

            ## compute losses
            z = seg.ravel()==np.c_[self.labelset]
            flatmask = mask.ravel()*np.ones((len(self.labelset),1))
 
            ## loss 0 : 1 - Dice(y,z)
            loss0 = loss_functions.ideal_loss(z,y,mask=flatmask)
            logger.info('Tloss = {}'.format(loss0))
            
            ## loss2: squared difference with ztilde
            #loss1 = loss_functions.anchor_loss(z,y,mask=flatmask)
            #logger.info('SDloss = {}'.format(loss1))
            
            ## loss3: laplacian loss
            #loss2 = loss_functions.laplacian_loss(z,y,mask=flatmask)
            #logger.info('LAPloss = {}'.format(loss2))
 
            ## loss4: linear loss
            #loss3 = loss_functions.linear_loss(z,y,mask=flatmask)
            #logger.info('LINloss = {}'.format(loss3))
            
            ## compute Dice coefficient per label
            dice    = compute_dice_coef(sol, seg,labelset=self.labelset)
            logger.info('Dice: {}'.format(dice))
            
            if not config.debug:
                outdir = config.dir_seg + \
                    '/{}/{}'.format('mean',test)
                logger.info('saving data in: {}'.format(outdir))
                if not os.path.isdir(outdir):
                    os.makedirs(outdir)
            
                #f = open(outdir + 'losses.txt', 'w')
                #f.write('ideal_loss\t{}\n'.format(loss0))
                #f.write('anchor_loss\t{}\n'.format(loss1))
                #f.write('laplacian_loss\t{}\n'.format(loss2))
                #f.close()
                
                io_analyze.save(outdir + 'sol.hdr', sol.astype(np.int32)) 

                np.savetxt(
                    outdir + 'dice.txt', np.c_[dice.keys(),dice.values()],fmt='%d %.8f')
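compute_dice_coef is used throughout to report a per-label Dice score and, given the fmt='%d %.8f' used when saving, appears to return a mapping from label to coefficient. A minimal sketch under that assumption (not the actual implementation):

import numpy as np

def dice_per_label_sketch(sol, seg, labelset):
    # sol: predicted label volume, seg: ground-truth label volume
    dice = {}
    for label in labelset:
        a = (sol == label)
        b = (seg == label)
        denom = a.sum() + b.sum()
        dice[int(label)] = 2.0 * np.sum(a & b) / denom if denom > 0 else 1.0
    return dice

# example
sol = np.array([0, 13, 13, 14])
seg = np.array([0, 13, 14, 14])
print(dice_per_label_sketch(sol, seg, [0, 13, 14]))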
Example #10
    def compute_approximate_aci(self, w, x, z, y0, **kwargs):
        logger.info("using approximate aci (Danny's)")
        islices = kwargs.pop('islices', None)
        imask = kwargs.pop('imask', None)
        iimask = kwargs.pop('iimask', None)
        if islices is not None:
            seeds = self.seeds[islices]
            mask = [
                self.immask[islices].ravel() for i in range(len(self.labelset))
            ]
            prior = {
                'data': np.asarray(self.prior['data'])[:, iimask],
                'imask': imask,
                'variance': np.asarray(self.prior['variance'])[:, iimask],
                'labelset': self.labelset,
            }
            if 'intensity' in self.prior:
                prior['intensity'] = self.prior['intensity']
        else:
            mask = self.mask
            seeds = self.seeds
            prior = self.prior

        weight_function = MetaLaplacianFunction(
            np.asarray(w)[self.indices_laplacians],
            self.laplacian_functions,
        )

        ## combine all prior models
        anchor_api = MetaAnchor(
            prior=prior,
            prior_models=self.prior_models,
            prior_weights=np.asarray(w)[self.indices_priors],
            image=x,
        )

        class GroundTruthAnchor(object):
            def __init__(self, anchor_api, gt, gt_weights):
                self.anchor_api = anchor_api
                self.gt = gt
                self.gt_weights = gt_weights

            def get_labelset(self):
                return self.anchor_api.get_labelset()

            def get_anchor_and_weights(self, D, indices):
                anchor, weights = self.anchor_api.get_anchor_and_weights(
                    D, indices)
                gt_weights = self.gt_weights[:, indices]
                gt = self.gt[:, indices]
                new_weights = weights + gt_weights
                new_anchor = (anchor * weights + gt * gt_weights) / new_weights
                return new_anchor, new_weights

        self.approx_aci_maxiter = 200
        self.approx_aci_maxstep = 1e-2
        z_weights = np.zeros(np.asarray(z).shape)
        z_label = np.argmax(z, axis=0)
        for i in range(self.approx_aci_maxiter):
            logger.debug("approx aci, iter={}".format(i))

            ## add ground truth to anchor api
            modified_api = GroundTruthAnchor(anchor_api, z, z_weights)

            ## inference
            y_ = rwsegment.segment(x,
                                   modified_api,
                                   seeds=seeds,
                                   weight_function=weight_function,
                                   return_arguments=['y'],
                                   **self.rwparams)

            ## loss
            #loss = self.compute_loss(z,y_, islices=islices)
            loss = loss_functions.ideal_loss(z, y_, mask=mask)
            logger.debug('loss = {}'.format(loss))
            if loss < 1e-8:
                break

            ## update weights
            delta = np.max(y_ - y_[z_label, np.arange(y_.shape[1])], axis=0)
            delta = np.clip(delta, 0, self.approx_aci_maxstep)
            z_weights += delta

        return y_
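The update at the end of each iteration measures, per pixel, how much the most competitive label exceeds the ground-truth label in y_, clips that violation to approx_aci_maxstep, and adds it to z_weights so that the next inference is pulled more strongly towards the ground truth at the offending pixels. A standalone illustration of that delta computation with toy numbers, same shapes as above:

import numpy as np

approx_aci_maxstep = 1e-2

y_ = np.array([[0.6, 0.3, 0.5],     # score of label 0 per pixel
               [0.4, 0.7, 0.5]])    # score of label 1 per pixel
z = np.array([[1, 1, 0],            # one-hot ground truth per pixel
              [0, 0, 1]])
z_label = np.argmax(z, axis=0)      # [0, 0, 1]

# margin violation: best competing score minus the ground-truth label's score
delta = np.max(y_ - y_[z_label, np.arange(y_.shape[1])], axis=0)
delta = np.clip(delta, 0, approx_aci_maxstep)

z_weights = np.zeros(y_.shape)
z_weights += delta                  # broadcast the per-pixel step to every label
print(delta, z_weights)             # only the misclassified middle pixel gets a step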