def create(
        self,
        arch,
        num_output_channels,
        num_input_channels,
        loss,
        lr,
        optimizer,
        lrsch,
        momentum=0.9,
        weight_decay=5e-4,
        pretrained=False,
        size_input=388,
        num_classes=8,
        backbone='preactresnet',
        num_filters=32,
        breal='real',
        alpha=2,
        beta=2,
    ):
        """
        Create the classification model, optimizer, scheduler, metric and loggers.
            -arch (string): architecture
            -loss (string): loss-function name
            -lr (float): learning rate
            -optimizer (string): optimizer name
            -lrsch (string): scheduler learning rate
            -pretrained (bool): load pretrained weights
            -size_input (int): expected input image size
            -num_classes (int): number of target classes
            -backbone (string): backbone family name (stored only)
            -num_filters (int): initial number of filters, forwarded via cfg_model
            -breal (string): 'real' vs. synthetic dataset flag (stored only)
            -alpha, beta: unused here — presumably loss weights; TODO confirm
        """
        # optimizer configuration forwarded to the parent create()
        cfg_opt = {'momentum': momentum, 'weight_decay': weight_decay}
        #cfg_scheduler={ 'step_size':100, 'gamma':0.1  }
        # plateau-style scheduler configuration (minimize monitored value)
        cfg_scheduler = {'mode': 'min', 'patience': 10}
        cfg_model = {'num_filters': num_filters}

        # set before the parent create() call: _create_model reads self.num_classes
        self.num_classes = num_classes

        super(ClassNeuralNet, self).create(
            arch,
            num_output_channels,
            num_input_channels,
            loss,
            lr,
            optimizer,
            lrsch,
            pretrained,
            cfg_opt=cfg_opt,
            cfg_scheduler=cfg_scheduler,
            cfg_model=cfg_model,
        )
        self.size_input = size_input
        self.backbone = backbone
        self.num_filters = num_filters

        # top-k accuracy metric used by training/evaluate
        self.topk = nloss.TopkAccuracy()

        # loggers track losses 'loss'/'loss_bce' and the 'topk' metric
        self.logger_train = Logger('Train', ['loss', 'loss_bce'], ['topk'],
                                   self.plotter)
        self.logger_val = Logger('Val  ', ['loss', 'loss_bce'], ['topk'],
                                 self.plotter)
        self.breal = breal
    def create(
            self,
            arch,
            num_output_channels,
            num_input_channels,
            loss,
            lr,
            optimizer,
            lrsch,
            momentum=0.9,
            weight_decay=5e-4,
            pretrained=False,
            topk=(1, ),
            size_input=128,
    ):
        """
        Create the model, loss, optimizer, scheduler, metrics and loggers.
        Args:
            arch (string): architecture
            num_output_channels (int): number of output classes
            num_input_channels (int): number of input channels
            loss (string): loss-function name
            lr (float): learning rate
            optimizer (string): optimizer name
            lrsch (string): scheduler learning rate
            momentum (float): optimizer momentum
            weight_decay (float): optimizer weight decay
            pretrained (bool): load pretrained weights
            topk (tuple): ranks for top-k accuracy metrics
            size_input (int): expected input image size
        """

        # FIX: honor the momentum/weight_decay arguments; previously these
        # were hard-coded to 0.9/5e-4 and the parameters were silently ignored.
        cfg_opt = {'momentum': momentum, 'weight_decay': weight_decay}
        cfg_scheduler = {'step_size': 100, 'gamma': 0.1}

        super(NeuralNetClassifier, self).create(
            arch,
            num_output_channels,
            num_input_channels,
            loss,
            lr,
            optimizer,
            lrsch,
            pretrained,
            cfg_opt=cfg_opt,
            cfg_scheduler=cfg_scheduler,
        )

        self.size_input = size_input
        self.accuracy = nloss.TopkAccuracy(topk)
        # normalized confusion matrix over all output classes
        self.cnf = nloss.ConfusionMeter(self.num_output_channels,
                                        normalized=True)
        self.visheatmap = gph.HeatMapVisdom(env_name=self.nameproject)

        # Set the graphic visualization: one 'top{k}' metric per requested rank
        self.metrics_name = ['top{}'.format(k) for k in topk]
        self.logger_train = Logger('Trn', ['loss'], self.metrics_name,
                                   self.plotter)
        self.logger_val = Logger('Val', ['loss'], self.metrics_name,
                                 self.plotter)
    def create(
        self,
        arch,
        num_output_channels,
        num_input_channels,
        loss,
        lr,
        optimizer,
        lrsch,
        momentum=0.9,
        weight_decay=5e-4,
        pretrained=False,
        th=0.0,
        size_input=32,
    ):
        """
        Create the multi-label model, loss, optimizer, scheduler and metrics.
        Args:
            arch (string): architecture
            num_output_channels (int): number of output classes
            num_input_channels (int): number of input channels
            loss (string): loss-function name
            lr (float): learning rate
            optimizer (string): optimizer name
            lrsch (string): scheduler learning rate
            momentum (float): optimizer momentum
            weight_decay (float): optimizer weight decay
            pretrained (bool): load pretrained weights
            th (float): decision threshold for accuracy / F-score metrics
            size_input (int): expected input image size
        """

        # FIX: honor the momentum/weight_decay arguments; previously these
        # were hard-coded to 0.9/5e-4 and the parameters were silently ignored.
        cfg_opt = {'momentum': momentum, 'weight_decay': weight_decay}
        cfg_scheduler = {'step_size': 50, 'gamma': 0.1}

        super(NeuralNetClassifier, self).create(
            arch,
            num_output_channels,
            num_input_channels,
            loss,
            lr,
            optimizer,
            lrsch,
            pretrained,
            cfg_opt=cfg_opt,
            cfg_scheduler=cfg_scheduler,
        )

        self.size_input = size_input
        # thresholded multi-label accuracy and F2 score
        self.accuracy = nloss.MultAccuracyV1(th)
        self.f_score = nloss.F_score(threshold=th, beta=2)

        #self.cnf = nloss.ConfusionMeter( self.num_output_channels, normalized=True )
        #self.visheatmap = gph.HeatMapVisdom( env_name=self.nameproject )

        # Set the graphic visualization
        self.logger_train = Logger('Trn', ['loss'], ['acc', 'f1'],
                                   self.plotter)
        self.logger_val = Logger('Val', ['loss'], ['acc', 'f1'], self.plotter)
    def create(self, 
        arch, 
        num_output_channels,
        num_input_channels,        
        loss,
        lr,
        optimizer,
        lrsch,
        momentum=0.9,
        weight_decay=5e-4,
        pretrained=False,
        size_input=388,
        num_classes=8,
        ):
        """
        Create the attention-GMM network and its loggers.
        Args:        
            -arch (string): architecture
            -num_output_channels (int): number of output channels
            -num_input_channels (int): number of input channels
            -loss (string): loss-function name
            -lr (float): learning rate
            -optimizer (string): optimizer name
            -lrsch (string): scheduler learning rate
            -pretrained (bool): load pretrained weights
            -size_input (int): expected input image size
            -num_classes (int): number of target classes
        """        
        # Unlike other create() variants in this file, momentum and
        # weight_decay are forwarded positionally to the parent create().
        super(AttentionGMMNeuralNet, self).create( 
            arch, 
            num_output_channels,
            num_input_channels,        
            loss,
            lr,
            optimizer,
            lrsch,
            momentum,
            weight_decay,
            pretrained,
            size_input,
            num_classes,          
        )

        # loggers track the combined loss plus its gmm/bce/attention terms,
        # and the 'topk'/'gmm' metrics
        self.logger_train = Logger( 'Train', ['loss', 'loss_gmm', 'loss_bce', 'loss_att' ], [ 'topk', 'gmm'], self.plotter  )
        self.logger_val   = Logger( 'Val  ', ['loss', 'loss_gmm', 'loss_bce', 'loss_att' ], [ 'topk', 'gmm'], self.plotter )
Example #5
0
    def create(
            self,
            arch,
            num_output_channels,
            num_input_channels,
            loss,
            lr,
            momentum,
            optimizer,
            lrsch,
            pretrained=False,
            topk=(1, ),
    ):
        """
        Create model, metrics and loggers (legacy signature: momentum is a
        required positional argument forwarded to the parent create()).
        Args:
            @arch (string): architecture
            @num_output_channels (int): number of output classes
            @num_input_channels (int): number of input channels
            @loss (string): loss-function name
            @lr (float): learning rate
            @momentum (float): optimizer momentum
            @optimizer (string): optimizer name
            @lrsch (string): scheduler learning rate
            @pretrained (bool): load pretrained weights
            @topk (tuple): ranks for top-k accuracy metrics
        """
        super(NeuralNetClassifier,
              self).create(arch, num_output_channels, num_input_channels, loss,
                           lr, momentum, optimizer, lrsch, pretrained)
        self.accuracy = nloss.Accuracy(topk)
        # normalized confusion matrix over all output classes
        self.cnf = nloss.ConfusionMeter(self.num_output_channels,
                                        normalized=True)
        self.visheatmap = gph.HeatMapVisdom(env_name=self.nameproject)

        # Set the graphic visualization: one 'top{k}' metric per requested rank
        self.metrics_name = ['top{}'.format(k) for k in topk]
        self.logger_train = Logger('Trn', ['loss'], self.metrics_name,
                                   self.plotter)
        self.logger_val = Logger('Val', ['loss'], self.metrics_name,
                                 self.plotter)
    def create(
        self,
        arch,
        num_output_channels,
        num_input_channels,
        loss,
        lr,
        momentum,
        optimizer,
        lrsch,
        pretrained=False,
        size_input=388,
    ):
        """
        Create the segmentation model, metrics, loggers and visualizers.
            -arch (string): architecture
            -loss (string): loss-function name
            -lr (float): learning rate
            -momentum (float): optimizer momentum (positional, forwarded)
            -optimizer (string): optimizer name
            -lrsch (string): scheduler learning rate
            -pretrained (bool): load pretrained weights
            -size_input (int): expected input image size
        """
        super(SegmentationNeuralNet,
              self).create(arch, num_output_channels, num_input_channels, loss,
                           lr, momentum, optimizer, lrsch, pretrained)
        self.size_input = size_input

        # segmentation metrics: pixel accuracy and Dice coefficient
        self.accuracy = nloss.Accuracy()
        self.dice = nloss.Dice()

        # Set the graphic visualization
        self.logger_train = Logger('Train', ['loss'], ['accs', 'dices'],
                                   self.plotter)
        self.logger_val = Logger('Val  ', ['loss'], ['accs', 'dices'],
                                 self.plotter)

        # visdom heatmap/image panels for qualitative inspection
        self.visheatmap = gph.HeatMapVisdom(env_name=self.nameproject,
                                            heatsize=(100, 100))
        self.visimshow = gph.ImageVisdom(env_name=self.nameproject,
                                         imsize=(100, 100))
class NeuralNetClassifier(NeuralNetAbstract):
    r"""Convolutional Neural Net for classification
    Args:
        patchproject (str): path project
        nameproject (str):  name project
        no_cuda (bool): system cuda (default is True)
        parallel (bool): wrap the model in nn.DataParallel
        seed (int): random seed
        print_freq (int): logging frequency in batches
        gpu (int): gpu index
    """
    def __init__(self,
                 patchproject,
                 nameproject,
                 no_cuda=True,
                 parallel=False,
                 seed=1,
                 print_freq=10,
                 gpu=0):
        super(NeuralNetClassifier,
              self).__init__(patchproject, nameproject, no_cuda, parallel,
                             seed, print_freq, gpu)

    def create(
            self,
            arch,
            num_output_channels,
            num_input_channels,
            loss,
            lr,
            optimizer,
            lrsch,
            momentum=0.9,
            weight_decay=5e-4,
            pretrained=False,
            topk=(1, ),
            size_input=128,
    ):
        """
        Create model, loss, optimizer, scheduler, metrics and loggers.
        Args:
            arch (string): architecture
            num_output_channels (int): number of output classes
            num_input_channels (int): number of input channels
            loss (string): loss-function name
            lr (float): learning rate
            optimizer (string): optimizer name
            lrsch (string): scheduler learning rate
            momentum (float): optimizer momentum
            weight_decay (float): optimizer weight decay
            pretrained (bool): load pretrained weights
            topk (tuple): ranks for top-k accuracy metrics
            size_input (int): expected input image size
        """

        # FIX: honor the momentum/weight_decay arguments; previously these
        # were hard-coded to 0.9/5e-4 and the parameters were silently ignored.
        cfg_opt = {'momentum': momentum, 'weight_decay': weight_decay}
        cfg_scheduler = {'step_size': 100, 'gamma': 0.1}

        super(NeuralNetClassifier, self).create(
            arch,
            num_output_channels,
            num_input_channels,
            loss,
            lr,
            optimizer,
            lrsch,
            pretrained,
            cfg_opt=cfg_opt,
            cfg_scheduler=cfg_scheduler,
        )

        self.size_input = size_input
        self.accuracy = nloss.TopkAccuracy(topk)
        # normalized confusion matrix over all output classes
        self.cnf = nloss.ConfusionMeter(self.num_output_channels,
                                        normalized=True)
        self.visheatmap = gph.HeatMapVisdom(env_name=self.nameproject)

        # Set the graphic visualization: one 'top{k}' metric per requested rank
        self.metrics_name = ['top{}'.format(k) for k in topk]
        self.logger_train = Logger('Trn', ['loss'], self.metrics_name,
                                   self.plotter)
        self.logger_val = Logger('Val', ['loss'], self.metrics_name,
                                 self.plotter)

    def training(self, data_loader, epoch=0):
        """Run one training epoch over data_loader and log losses/metrics."""

        self.logger_train.reset()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, sample in enumerate(data_loader):

            # measure data loading time
            data_time.update(time.time() - end)
            # get data (image, label); labels arrive one-hot, reduce to indices
            x, y = sample['image'], sample['label'].argmax(1).long()
            batch_size = x.size(0)

            if self.cuda:
                x = x.cuda()
                y = y.cuda()

            # fit (forward)
            outputs = self.net(x)

            # measure accuracy and record loss
            loss = self.criterion(outputs, y.long())
            pred = self.accuracy(outputs.data, y)

            # optimizer
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # update logger with scalar loss and per-rank top-k values
            self.logger_train.update(
                {'loss': loss.item()},
                dict(
                    zip(self.metrics_name, [
                        pred[p].item() for p in range(len(self.metrics_name))
                    ])),
                batch_size,
            )

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                self.logger_train.logger(
                    epoch,
                    epoch + float(i + 1) / len(data_loader),
                    i,
                    len(data_loader),
                    batch_time,
                )

    def evaluate(self, data_loader, epoch=0):
        """Evaluate on data_loader; returns average top-1 accuracy.

        NOTE(review): assumes data_loader is non-empty (uses the loop
        variable afterwards) and that rank 1 is in ``topk`` (reads 'top1').
        """

        self.logger_val.reset()
        self.cnf.reset()
        batch_time = AverageMeter()

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, sample in enumerate(data_loader):

                # get data (image, label)
                x, y = sample['image'], sample['label'].argmax(1).long()
                batch_size = x.size(0)

                if self.cuda:
                    x = x.cuda()
                    y = y.cuda()

                # fit (forward)
                outputs = self.net(x)

                # measure accuracy and record loss
                loss = self.criterion(outputs, y)
                pred = self.accuracy(outputs.data, y.data)
                self.cnf.add(outputs.argmax(1), y)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # update
                self.logger_val.update(
                    {'loss': loss.item()},
                    dict(
                        zip(self.metrics_name, [
                            pred[p].item()
                            for p in range(len(self.metrics_name))
                        ])),
                    batch_size,
                )

                if i % self.print_freq == 0:
                    self.logger_val.logger(
                        epoch,
                        epoch,
                        i,
                        len(data_loader),
                        batch_time,
                        bplotter=False,
                        bavg=True,
                        bsummary=False,
                    )

        # save validation loss
        self.vallosses = self.logger_val.info['loss']['loss'].avg
        acc = self.logger_val.info['metrics']['top1'].avg

        self.logger_val.logger(
            epoch,
            epoch,
            i,
            len(data_loader),
            batch_time,
            bplotter=True,
            bavg=True,
            bsummary=True,
        )

        print('Confusion Matriz')
        print(self.cnf.value(), flush=True)
        print('\n')

        self.visheatmap.show('Confusion Matriz', self.cnf.value())
        return acc

    def test(self, data_loader):
        """Return (Yhat, Y): softmax predictions and true labels."""

        # upper bound on sample count; trimmed to k below (last batch may be short)
        n = len(data_loader) * data_loader.batch_size
        Yhat = np.zeros((n, self.num_output_channels))
        Y = np.zeros((n, 1))
        k = 0

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            for sample in tqdm(data_loader):

                # get data (image, label)
                x, y = sample['image'], sample['label'].argmax(1).long()
                x = x.cuda() if self.cuda else x

                # fit (forward)
                yhat = self.net(x)
                yhat = F.softmax(yhat, dim=1)
                yhat = pytutils.to_np(yhat)

                for j in range(yhat.shape[0]):
                    Y[k] = y[j]
                    Yhat[k, :] = yhat[j]
                    k += 1

        Yhat = Yhat[:k, :]
        Y = Y[:k]

        return Yhat, Y

    def predict(self, data_loader):
        """Return (Ids, Yhat): sample ids and softmax predictions."""

        # upper bound on sample count; trimmed to k below (last batch may be short)
        n = len(data_loader) * data_loader.batch_size
        Yhat = np.zeros((n, self.num_output_channels))
        Ids = np.zeros((n, 1))
        k = 0

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            for Id, inputs in tqdm(data_loader):
                x = inputs.cuda() if self.cuda else inputs

                # fit (forward)
                yhat = self.net(x)
                yhat = F.softmax(yhat, dim=1)
                yhat = pytutils.to_np(yhat)

                for j in range(yhat.shape[0]):
                    Yhat[k, :] = yhat[j]
                    Ids[k] = Id[j]
                    k += 1

        Yhat = Yhat[:k, :]
        Ids = Ids[:k]

        return Ids, Yhat

    def __call__(self, image):
        """Return softmax class probabilities for a batch of images."""
        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            x = image.cuda() if self.cuda else image
            # FIX: specify dim=1 explicitly (nn.Softmax() without dim is
            # deprecated); matches F.softmax(..., dim=1) used elsewhere here.
            msoft = nn.Softmax(dim=1)
            yhat = msoft(self.net(x))
            yhat = pytutils.to_np(yhat)

        return yhat

    def representation(self, data_loader):
        """
        Compute embedded representations for a dataset.
            -data_loader: simple data loader for image
        Returns:
            (embX, embY): feature matrix and label vector as numpy arrays.
        """

        # switch to evaluate mode
        self.net.eval()

        n = len(data_loader) * data_loader.batch_size
        k = 0

        # embebed features
        embX = np.zeros([n, self.net.dim])
        embY = np.zeros([n, 1])

        batch_time = AverageMeter()
        end = time.time()
        for i, sample in enumerate(data_loader):

            # get data (image, label)
            x, y = sample['image'], sample['label'].argmax(1).long()
            x = x.cuda() if self.cuda else x

            # representation
            emb = self.net.representation(x)
            emb = pytutils.to_np(emb)
            for j in range(emb.shape[0]):
                embX[k, :] = emb[j, :]
                embY[k] = y[j]
                k += 1

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            print(
                'Representation: |{:06d}/{:06d}||{batch_time.val:.3f} ({batch_time.avg:.3f})|'
                .format(i, len(data_loader), batch_time=batch_time))

        embX = embX[:k, :]
        embY = embY[:k]

        return embX, embY

    def _create_model(self, arch, num_output_channels, num_input_channels,
                      pretrained):
        """
        Create model
            @arch (string): select architecture
            @num_output_channels (int): number of output classes
            @num_input_channels (int): number of input channels
            @pretrained (bool): load pretrained weights
        """

        self.net = None
        self.size_input = 0

        kw = {
            'num_classes': num_output_channels,
            'num_channels': num_input_channels,
            'pretrained': pretrained
        }
        # look up the architecture factory by name
        self.net = nnmodels.__dict__[arch](**kw)

        self.s_arch = arch
        self.size_input = self.net.size_input
        self.num_output_channels = num_output_channels
        self.num_input_channels = num_input_channels

        if self.cuda:
            self.net.cuda()
        if self.parallel and self.cuda:
            self.net = nn.DataParallel(self.net,
                                       device_ids=range(
                                           torch.cuda.device_count()))

    def _create_loss(self, loss):
        """
        Create the training criterion from its name.
            @loss (string): one of 'cross', 'mse', 'l1'
        Raises:
            ValueError: for an unknown loss name.
        """

        # create loss ('size_average=True' is deprecated; reduction='mean'
        # is the equivalent modern spelling)
        if loss == 'cross':
            self.criterion = nn.CrossEntropyLoss()
        elif loss == 'mse':
            self.criterion = nn.MSELoss(reduction='mean')
        elif loss == 'l1':
            self.criterion = nn.L1Loss(reduction='mean')
        else:
            # FIX: raise instead of assert(False) (asserts vanish under -O)
            raise ValueError('Unsupported loss: {}'.format(loss))

        # FIX: only move the criterion to GPU when cuda is enabled;
        # the unconditional .cuda() crashed on CPU-only machines.
        if self.cuda:
            self.criterion = self.criterion.cuda()

        self.s_loss = loss
class ClassNeuralNet(NeuralNetAbstract):
    """
    classification Neural Net like preactresnet
    """
    def __init__(self,
                 patchproject,
                 nameproject,
                 no_cuda=True,
                 parallel=False,
                 seed=1,
                 print_freq=10,
                 gpu=0,
                 view_freq=1):
        """
        Initialization
            -patchproject (str): path project
            -nameproject (str):  name project
            -no_cuda (bool): system cuda (default is True)
            -parallel (bool)
            -seed (int)
            -print_freq (int)
            -gpu (int)
            -view_freq (in epochs)
        """

        super(ClassNeuralNet,
              self).__init__(patchproject, nameproject, no_cuda, parallel,
                             seed, print_freq, gpu)
        self.view_freq = view_freq

    def create(
        self,
        arch,
        num_output_channels,
        num_input_channels,
        loss,
        lr,
        optimizer,
        lrsch,
        momentum=0.9,
        weight_decay=5e-4,
        pretrained=False,
        size_input=388,
        num_classes=8,
        backbone='preactresnet',
        num_filters=32,
        breal='real',
        alpha=2,
        beta=2,
    ):
        """
        Create the model, optimizer, scheduler, metric and loggers.
            -arch (string): architecture
            -loss (string): loss-function name
            -lr (float): learning rate
            -optimizer (string): optimizer name
            -lrsch (string): scheduler learning rate
            -pretrained (bool): load pretrained weights
            -size_input (int): expected input image size
            -num_classes (int): number of target classes
            -backbone (string): backbone family name (stored only)
            -num_filters (int): initial number of filters (via cfg_model)
            -breal (string): 'real' vs synthetic dataset flag
            -alpha, beta: unused here; presumably loss weights — TODO confirm
        """
        cfg_opt = {'momentum': momentum, 'weight_decay': weight_decay}
        #cfg_scheduler={ 'step_size':100, 'gamma':0.1  }
        # plateau-style scheduler configuration (minimize monitored value)
        cfg_scheduler = {'mode': 'min', 'patience': 10}
        cfg_model = {'num_filters': num_filters}

        # set before the parent create(): _create_model reads self.num_classes
        self.num_classes = num_classes

        super(ClassNeuralNet, self).create(
            arch,
            num_output_channels,
            num_input_channels,
            loss,
            lr,
            optimizer,
            lrsch,
            pretrained,
            cfg_opt=cfg_opt,
            cfg_scheduler=cfg_scheduler,
            cfg_model=cfg_model,
        )
        self.size_input = size_input
        self.backbone = backbone
        self.num_filters = num_filters

        # top-k accuracy metric used by training/evaluate
        self.topk = nloss.TopkAccuracy()

        self.logger_train = Logger('Train', ['loss', 'loss_bce'], ['topk'],
                                   self.plotter)
        self.logger_val = Logger('Val  ', ['loss', 'loss_bce'], ['topk'],
                                 self.plotter)
        self.breal = breal

        # Set the graphic visualization
        # self.visheatmap = gph.HeatMapVisdom(env_name=self.nameproject, heatsize=(100,100) )

    def _create_model(self, arch, num_output_channels, num_input_channels,
                      pretrained, **kwargs):
        """
        Create model
            -arch (string): select architecture
            -num_output_channels (int): stored for checkpointing
            -num_input_channels (int): number of input channels
            -pretrained (bool): load pretrained weights
            -kwargs: may carry 'num_filters' from cfg_model
        """

        self.net = None

        #--------------------------------------------------------------------------------------------
        # select architecture
        #--------------------------------------------------------------------------------------------
        num_filters = kwargs.get("num_filters", 32)
        # num_classes=1000, num_channels=3, initial_channels=64 for preactresnet

        # note: the network is built with self.num_classes outputs, not
        # num_output_channels (which is only recorded for checkpoints)
        kw = {
            'num_classes': self.num_classes,
            'num_channels': num_input_channels,
            'pretrained': pretrained
        }
        print("kw", kw)
        self.net = nnmodels.__dict__[arch](**kw)

        self.s_arch = arch
        self.num_output_channels = num_output_channels
        self.num_input_channels = num_input_channels
        self.num_filters = num_filters
        if self.cuda:
            self.net.cuda()
        if self.parallel and self.cuda:
            self.net = nn.DataParallel(self.net,
                                       device_ids=range(
                                           torch.cuda.device_count()))

    def save(self, epoch, prec, is_best=False, filename='checkpoint.pth.tar'):
        """
        Save model checkpoint (weights, optimizer state and metadata).
        """
        print('>> save model epoch {} ({}) in {}'.format(
            epoch, prec, filename))
        # unwrap DataParallel so the checkpoint is loadable without it
        net = self.net.module if self.parallel else self.net
        pytutils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': self.s_arch,
                'imsize': self.size_input,
                'num_output_channels': self.num_output_channels,
                'num_input_channels': self.num_input_channels,
                'num_classes': self.num_classes,
                'state_dict': net.state_dict(),
                'prec': prec,
                'optimizer': self.optimizer.state_dict(),
            }, is_best, self.pathmodels, filename)

    def load(self, pathnamemodel):
        """
        load model from pretrained model
        :param pathnamemodel: model path
        :return: True when a checkpoint was found and loaded
        """
        bload = False
        if pathnamemodel:
            if os.path.isfile(pathnamemodel):
                print("=> loading checkpoint '{}'".format(pathnamemodel))
                # map to CPU when cuda is unavailable
                checkpoint = torch.load(
                    pathnamemodel) if self.cuda else torch.load(
                        pathnamemodel,
                        map_location=lambda storage, loc: storage)
                self.num_classes = checkpoint['num_classes']
                self._create_model(checkpoint['arch'],
                                   checkpoint['num_output_channels'],
                                   checkpoint['num_input_channels'], False)
                self.size_input = checkpoint['imsize']
                self.net.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint for {} arch!".format(
                    checkpoint['arch']))
                bload = True
            else:
                print("=> no checkpoint found at '{}'".format(pathnamemodel))
        return bload

    def training(self, data_loader, epoch=0, *args):
        """Run one training epoch; handles both real and synthetic samples."""
        #reset logger
        self.logger_train.reset()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, sample in enumerate(data_loader):

            # measure data loading time
            data_time.update(time.time() - end)
            # if dataset is real: dict sample with one-hot labels
            if self.breal == 'real':
                x_img, y_lab = sample['image'], sample['label']
                y_lab = y_lab.argmax(dim=1)

                if self.cuda:
                    x_img = x_img.cuda()
                    y_lab = y_lab.cuda()

                x_org = x_img.clone().detach()

            else:
                # if dataset is synthetic: tuple sample; meta packs the
                # label in column 0 and a 2x3 affine theta in the rest
                x_org, x_img, y_mask, meta = sample

                y_lab = meta[:, 0]
                y_theta = meta[:, 1:].view(-1, 2, 3)

                if self.cuda:
                    x_org = x_org.cuda()
                    x_img = x_img.cuda()
                    y_mask = y_mask.cuda()
                    y_lab = y_lab.cuda()
                    y_theta = y_theta.cuda()

            # fit (forward)
            y_lab_hat = self.net(x_img)

            # calculate classification loss
            loss_bce = self.criterion_bce(y_lab_hat, y_lab.long())
            loss = loss_bce
            # accuracy of choosing top k predicted classes
            topk = self.topk(y_lab_hat, y_lab.long())

            batch_size = x_img.shape[0]

            # optimizer
            # NOTE(review): gradient is deliberately scaled by batch_size
            # (sum- rather than mean-style update) — confirm intent
            self.optimizer.zero_grad()
            (loss * batch_size).backward()
            self.optimizer.step()

            # update
            self.logger_train.update(
                {
                    'loss': loss.cpu().item(),
                    'loss_bce': loss_bce.cpu().item()
                },
                {'topk': topk[0][0].cpu()},
                batch_size,
            )

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                self.logger_train.logger(
                    epoch,
                    epoch + float(i + 1) / len(data_loader),
                    i,
                    len(data_loader),
                    batch_time,
                )

    def evaluate(self, data_loader, epoch=0, *args):
        """
        evaluate on validation dataset
        :param data_loader: which data_loader to use
        :param epoch: current epoch
        :return:
            acc: average accuracy on data_loader
        """
        # reset loader
        self.logger_val.reset()
        batch_time = AverageMeter()

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, sample in enumerate(data_loader):

                # get data (image, label)
                if self.breal == 'real':
                    x_img, y_lab = sample["image"], sample["label"]
                    y_lab = y_lab.argmax(dim=1)

                    if self.cuda:
                        x_img = x_img.cuda()
                        y_lab = y_lab.cuda()
                    x_org = x_img.clone().detach()
                else:
                    x_org, x_img, y_mask, meta = sample

                    y_lab = meta[:, 0]
                    y_theta = meta[:, 1:].view(-1, 2, 3)

                    if self.cuda:
                        x_org = x_org.cuda()
                        x_img = x_img.cuda()
                        y_mask = y_mask.cuda()
                        y_lab = y_lab.cuda()
                        y_theta = y_theta.cuda()

                # fit (forward)
                # print("x_img size", x_img.size())
                y_lab_hat = self.net(x_img)

                # measure accuracy and record loss
                loss_bce = self.criterion_bce(y_lab_hat, y_lab.long())
                loss = loss_bce
                topk = self.topk(y_lab_hat, y_lab.long())

                batch_size = x_img.shape[0]

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # update
                self.logger_val.update(
                    {
                        'loss': loss.cpu().item(),
                        'loss_bce': loss_bce.cpu().item()
                    },
                    {'topk': topk[0][0].cpu()},
                    batch_size,
                )

                # print the result in certain print frequency when iterating batches
                if i % self.print_freq == 0:
                    self.logger_val.logger(
                        epoch,
                        epoch,
                        i,
                        len(data_loader),
                        batch_time,
                        bplotter=False,
                        bavg=True,
                        bsummary=False,
                    )

        #save validation loss and accuracy
        self.vallosses = self.logger_val.info['loss']['loss'].avg
        acc = self.logger_val.info['metrics']['topk'].avg

        # print the average loss and accuracy after one iteration
        self.logger_val.logger(
            epoch,
            epoch,
            i,
            len(data_loader),
            batch_time,
            bplotter=True,
            bavg=True,
            bsummary=True,
        )

        return acc

    def representation(self, dataloader, breal='real'):
        """
        :param dataloader:
        :param breal:'real' or 'synthetic'
        :return:
            Y_labs: true labels
            Y_lab_hats: predicted logits
        """
        Y_labs = []
        Y_lab_hats = []
        self.net.eval()
        with torch.no_grad():
            for i_batch, sample in enumerate(tqdm(dataloader)):

                if breal == 'real':
                    x_img, y_lab = sample['image'], sample['label']
                    y_lab = y_lab.argmax(dim=1)
                else:
                    x_org, x_img, y_mask, y_lab = sample
                    y_lab = y_lab[:, 0]

                if self.cuda:
                    x_img = x_img.cuda()
                y_lab_hat = self.net(x_img)
                Y_labs.append(y_lab)
                Y_lab_hats.append(y_lab_hat.data.cpu())

        Y_labs = np.concatenate(Y_labs, axis=0)
        Y_lab_hats = np.concatenate(Y_lab_hats, axis=0)
        return Y_labs, Y_lab_hats

    def __call__(self, image):
        """Forward a batch; returns (softmax probs, attention, fmap, srf)."""
        # when calling the class, switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            x = image.cuda() if self.cuda else image
            # NOTE(review): here the net is expected to return a 4-tuple,
            # while training/evaluate treat its output as plain logits —
            # confirm which architecture this wrapper is used with
            y_lab_hat, att, fmap, srf = self.net(x)
            y_lab_hat = F.softmax(y_lab_hat, dim=1)
        return y_lab_hat, att, fmap, srf

    def _create_loss(self, loss, alpha=None, beta=None):
        """
        Create the cross-entropy criterion.
            -loss (string): loss name, recorded in self.s_loss
            -alpha, beta: accepted for interface compatibility; unused here
        """
        # create cross entropy loss
        self.criterion_bce = nn.CrossEntropyLoss()
        # FIX: only move the criterion to GPU when cuda is enabled; the
        # unconditional .cuda() crashed on CPU-only machines (matches the
        # self.cuda guard used in _create_model).
        if self.cuda:
            self.criterion_bce = self.criterion_bce.cuda()
        self.s_loss = loss
class NeuralNetClassifier(NeuralNetAbstract):
    r"""Convolutional Neural Net for multi-label classification.

    Trains with a configurable criterion and tracks accuracy / F-score.
    Args:
        patchproject (str): path project
        nameproject (str):  name project
        no_cuda (bool): system cuda (default is True)
        parallel (bool)
        seed (int)
        print_freq (int)
        gpu (int)
    """
    def __init__(self,
                 patchproject,
                 nameproject,
                 no_cuda=True,
                 parallel=False,
                 seed=1,
                 print_freq=10,
                 gpu=0):
        # All shared state (cuda flags, seed, plotter, ...) is set up by the
        # abstract base class.
        super(NeuralNetClassifier,
              self).__init__(patchproject, nameproject, no_cuda, parallel,
                             seed, print_freq, gpu)

    def create(
        self,
        arch,
        num_output_channels,
        num_input_channels,
        loss,
        lr,
        optimizer,
        lrsch,
        momentum=0.9,
        weight_decay=5e-4,
        pretrained=False,
        th=0.0,
        size_input=32,
    ):
        """
        Create the model, optimizer, scheduler, metrics and loggers.
        Args:
            arch (string): architecture
            num_output_channels (int): number of output classes/labels
            num_input_channels (int): number of input image channels
            loss (string): loss name (see ``_create_loss``)
            lr (float): learning rate
            optimizer (string): optimizer name
            lrsch (string): scheduler learning rate
            momentum (float): SGD momentum
            weight_decay (float): L2 penalty
            pretrained (bool)
            th (float): decision threshold for the accuracy/F-score metrics
            size_input (int): expected input image size
        """

        # BUGFIX: forward the momentum/weight_decay arguments instead of
        # hard-coding 0.9 / 5e-4 (the old code silently discarded caller
        # supplied values; defaults are unchanged, so behavior with default
        # arguments is identical).
        cfg_opt = {'momentum': momentum, 'weight_decay': weight_decay}
        cfg_scheduler = {'step_size': 50, 'gamma': 0.1}

        super(NeuralNetClassifier, self).create(
            arch,
            num_output_channels,
            num_input_channels,
            loss,
            lr,
            optimizer,
            lrsch,
            pretrained,
            cfg_opt=cfg_opt,
            cfg_scheduler=cfg_scheduler,
        )

        self.size_input = size_input
        self.accuracy = nloss.MultAccuracyV1(th)
        self.f_score = nloss.F_score(threshold=th, beta=2)

        #self.cnf = nloss.ConfusionMeter( self.num_output_channels, normalized=True )
        #self.visheatmap = gph.HeatMapVisdom( env_name=self.nameproject )

        # Set the graphic visualization
        self.logger_train = Logger('Trn', ['loss'], ['acc', 'f1'],
                                   self.plotter)
        self.logger_val = Logger('Val', ['loss'], ['acc', 'f1'], self.plotter)

    def training(self, data_loader, epoch=0):
        """Run one training epoch and log loss / accuracy / F1."""

        self.logger_train.reset()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, (iD, image, prob) in enumerate(data_loader):

            # measure data loading time
            data_time.update(time.time() - end)
            x, y = image, prob
            batch_size = x.size(0)

            if self.cuda:
                x = x.cuda()
                y = y.cuda()

            # fit (forward)
            yhat = self.net(x)

            # measure accuracy and record loss
            loss = self.criterion(yhat, y.float())
            pred = self.accuracy(yhat.data, y.data)
            f1 = self.f_score(yhat.data, y.data)

            # optimizer
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # update logger
            # BUGFIX: loss.data[0] raises IndexError on the 0-dim scalar
            # losses of PyTorch >= 0.4 (the torch version this file requires
            # for torch.no_grad()); .item() is the supported accessor.
            # NOTE(review): pred/f1 come from project metrics (nloss); the
            # .data[0] access assumes they are 1-dim tensors — confirm, else
            # switch them to .item() as well.
            self.logger_train.update(
                {'loss': loss.item()},
                {
                    'acc': pred.data[0],
                    'f1': f1.data[0]
                },
                batch_size,
            )

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                self.logger_train.logger(
                    epoch,
                    epoch + float(i + 1) / len(data_loader),
                    i,
                    len(data_loader),
                    batch_time,
                )

    def evaluate(self, data_loader, epoch=0):
        """Evaluate on a validation loader; returns mean accuracy."""

        self.logger_val.reset()
        #self.cnf.reset()
        batch_time = AverageMeter()

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, (iD, image, prob) in enumerate(data_loader):

                # get data (image, label)
                x, y = image, prob  #.argmax(1).long()
                batch_size = x.size(0)

                if self.cuda:
                    x = x.cuda()
                    y = y.cuda()

                # fit (forward)
                yhat = self.net(x)

                # measure accuracy and record loss
                loss = self.criterion(yhat, y.float())
                pred = self.accuracy(yhat.data, y.data)
                f1 = self.f_score(yhat.data, y.data)

                #self.cnf.add( outputs.argmax(1), y )

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # update (see training() for the .item()/.data[0] note)
                self.logger_val.update(
                    {'loss': loss.item()},
                    {
                        'acc': pred.data[0],
                        'f1': f1.data[0]
                    },
                    batch_size,
                )

                if i % self.print_freq == 0:
                    self.logger_val.logger(
                        epoch,
                        epoch,
                        i,
                        len(data_loader),
                        batch_time,
                        bplotter=False,
                        bavg=True,
                        bsummary=False,
                    )

        # save validation loss and report the epoch summary
        self.vallosses = self.logger_val.info['loss']['loss'].avg
        acc = self.logger_val.info['metrics']['acc'].avg

        self.logger_val.logger(
            epoch,
            epoch,
            i,
            len(data_loader),
            batch_time,
            bplotter=True,
            bavg=True,
            bsummary=True,
        )

        #print('Confusion Matriz')
        #print(self.cnf.value(), flush=True)
        #print('\n')
        #self.visheatmap.show('Confusion Matriz', self.cnf.value())

        return acc

    def predict(self, data_loader):
        """Predict sigmoid scores for every sample in the loader.

        Returns:
            (iDs, Yhat, Y): sample ids, predicted probabilities and the
            ground-truth label vectors, each as a numpy array.
        """
        Yhat, iDs, Y = [], [], []
        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            for iD, image, prob in tqdm(data_loader):
                x = image.cuda() if self.cuda else image
                yhat = self.net(x)
                # torch.sigmoid replaces the deprecated F.sigmoid (identical).
                yhat = torch.sigmoid(yhat).cpu().numpy()
                Yhat.append(yhat)
                iDs.append(iD)
                Y.append(prob)
        Yhat = np.concatenate(Yhat, axis=0)
        iDs = np.concatenate(iDs, axis=0)
        Y = np.concatenate(Y, axis=0)
        return iDs, Yhat, Y

    def __call__(self, image):
        """Return sigmoid scores for a single batch as a numpy array."""

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            x = image.cuda() if self.cuda else image
            yhat = self.net(x)
            yhat = torch.sigmoid(yhat).cpu().numpy()
        return yhat

    def _create_model(self, arch, num_output_channels, num_input_channels,
                      pretrained):
        """
        Create model
            @arch (string): select architecture
            @num_output_channels (int)
            @num_input_channels (int)
            @pretrained (bool)
        """
        self.net = None
        self.size_input = 0

        kw = {
            'num_classes': num_output_channels,
            'num_channels': num_input_channels,
            'pretrained': pretrained
        }
        # Architectures are looked up by name in the project's model registry.
        self.net = nnmodels.__dict__[arch](**kw)

        self.s_arch = arch
        self.size_input = self.net.size_input
        self.num_output_channels = num_output_channels
        self.num_input_channels = num_input_channels

        if self.cuda:
            self.net.cuda()
        if self.parallel and self.cuda:
            self.net = nn.DataParallel(self.net,
                                       device_ids=range(
                                           torch.cuda.device_count()))

    def _create_loss(self, loss):
        """Instantiate ``self.criterion`` from the loss name.

        Raises:
            ValueError: if the loss name is unknown.
        """
        # NOTE(review): size_average is deprecated in modern torch (use
        # reduction='mean'); kept for compatibility with the legacy torch
        # version this project targets.
        if loss == 'cross':
            self.criterion = nn.CrossEntropyLoss().cuda()
        elif loss == 'mse':
            self.criterion = nn.MSELoss(size_average=True).cuda()
        elif loss == 'bcewl':
            self.criterion = nn.BCEWithLogitsLoss(size_average=True).cuda()
        elif loss == 'l1':
            self.criterion = nn.L1Loss(size_average=True).cuda()
        elif loss == 'focal':
            self.criterion = nloss.FocalLoss(gamma=2).cuda()
        elif loss == 'dice':
            self.criterion = nloss.DiceLoss().cuda()
        elif loss == 'mix':
            self.criterion = nloss.MixLoss().cuda()
        else:
            # BUGFIX: was `assert (False)`, which is stripped under `python -O`
            # and gives no diagnostic; raise an explicit error instead.
            raise ValueError('invalid loss: {}'.format(loss))

        self.s_loss = loss
# Exemple #10  (separator scraped from a code-example site; commented out —
# the bare identifier and stray vote count "0" are not code and would break
# importing this module)
class NeuralNetClassifier(NeuralNetAbstract):
    """
    Convolutional Neural Net for single-label classification.

    Tracks top-k accuracy and a confusion matrix; visualizes results via
    Visdom. NOTE(review): this code is written for legacy PyTorch
    (Variable(..., volatile=True), loss.data[0]); on torch >= 0.4 the
    volatile flag is ignored with a warning and loss.data[0] raises on
    0-dim scalars — confirm the target torch version before reuse.
    """
    def __init__(self,
                 patchproject,
                 nameproject,
                 no_cuda=True,
                 parallel=False,
                 seed=1,
                 print_freq=10,
                 gpu=0):
        """
        Initialization
        Args:
            @patchproject (str): path project
            @nameproject (str):  name project
            @no_cuda (bool): system cuda (default is True)
            @parallel (bool)
            @seed (int)
            @print_freq (int)
            @gpu (int)
        """

        # All shared state is set up by the abstract base class.
        super(NeuralNetClassifier,
              self).__init__(patchproject, nameproject, no_cuda, parallel,
                             seed, print_freq, gpu)

    def create(
            self,
            arch,
            num_output_channels,
            num_input_channels,
            loss,
            lr,
            momentum,
            optimizer,
            lrsch,
            pretrained=False,
            topk=(1, ),
    ):
        """
        Create the model, metrics and loggers.
        Args:
            @arch (string): architecture
            @num_output_channels (int): number of classes
            @num_input_channels (int): number of input channels
            @loss (string): loss name (see ``_create_loss``)
            @lr (float): learning rate
            @momentum: SGD momentum
            @optimizer (string): optimizer name
            @lrsch (string): scheduler learning rate
            @pretrained (bool)
            @topk (tuple): k values for top-k accuracy; must include 1,
                since ``evaluate`` reads the 'top1' metric.
        """
        super(NeuralNetClassifier,
              self).create(arch, num_output_channels, num_input_channels, loss,
                           lr, momentum, optimizer, lrsch, pretrained)
        self.accuracy = nloss.Accuracy(topk)
        # Normalized confusion matrix accumulated during evaluate().
        self.cnf = nloss.ConfusionMeter(self.num_output_channels,
                                        normalized=True)
        self.visheatmap = gph.HeatMapVisdom(env_name=self.nameproject)

        # Set the graphic visualization: one metric column per requested k.
        self.metrics_name = ['top{}'.format(k) for k in topk]
        self.logger_train = Logger('Trn', ['loss'], self.metrics_name,
                                   self.plotter)
        self.logger_val = Logger('Val', ['loss'], self.metrics_name,
                                 self.plotter)

    def training(self, data_loader, epoch=0):
        """Run one training epoch over ``data_loader`` and log loss/top-k."""

        self.logger_train.reset()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, sample in enumerate(data_loader):

            # measure data loading time
            data_time.update(time.time() - end)
            # get data (image, label); labels arrive one-hot and are reduced
            # to class indices via pytutils.argmax.
            inputs, targets = sample['image'], pytutils.argmax(sample['label'])
            batch_size = inputs.size(0)

            if self.cuda:
                targets = targets.cuda(non_blocking=True)
                inputs_var = Variable(inputs.cuda(), requires_grad=False)
                targets_var = Variable(targets.cuda(), requires_grad=False)
            else:
                inputs_var = Variable(inputs, requires_grad=False)
                targets_var = Variable(targets, requires_grad=False)

            # fit (forward)
            outputs = self.net(inputs_var)

            # measure accuracy and record loss
            loss = self.criterion(outputs, targets_var)
            pred = self.accuracy(outputs.data, targets)

            # optimizer
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # update logger with one entry per top-k metric.
            # NOTE(review): loss.data[0] / pred[p][0] are legacy (torch < 0.4)
            # scalar accessors — use .item() on newer torch.
            self.logger_train.update(
                {'loss': loss.data[0]},
                dict(
                    zip(self.metrics_name,
                        [pred[p][0] for p in range(len(self.metrics_name))])),
                batch_size,
            )

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                self.logger_train.logger(
                    epoch,
                    epoch + float(i + 1) / len(data_loader),
                    i,
                    len(data_loader),
                    batch_time,
                )

    def evaluate(self, data_loader, epoch=0):
        """Evaluate on ``data_loader``; returns mean top-1 accuracy and
        prints/plots the confusion matrix."""

        self.logger_val.reset()
        self.cnf.reset()
        batch_time = AverageMeter()

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, sample in enumerate(data_loader):

                # get data (image, label)
                inputs, targets = sample['image'], pytutils.argmax(
                    sample['label'])
                batch_size = inputs.size(0)

                # NOTE(review): volatile=True was removed in torch 0.4 (the
                # surrounding torch.no_grad() already disables autograd).
                if self.cuda:
                    targets = targets.cuda(non_blocking=True)
                    inputs_var = Variable(inputs.cuda(),
                                          requires_grad=False,
                                          volatile=True)
                    targets_var = Variable(targets.cuda(),
                                           requires_grad=False,
                                           volatile=True)
                else:
                    inputs_var = Variable(inputs,
                                          requires_grad=False,
                                          volatile=True)
                    targets_var = Variable(targets,
                                           requires_grad=False,
                                           volatile=True)

                # fit (forward)
                outputs = self.net(inputs_var)

                # measure accuracy and record loss
                loss = self.criterion(outputs, targets_var)
                pred = self.accuracy(outputs.data, targets)
                self.cnf.add(outputs.argmax(1), targets_var)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # update
                self.logger_val.update(
                    {'loss': loss.data[0]},
                    dict(
                        zip(self.metrics_name, [
                            pred[p][0] for p in range(len(self.metrics_name))
                        ])),
                    batch_size,
                )

                if i % self.print_freq == 0:
                    self.logger_val.logger(
                        epoch,
                        epoch,
                        i,
                        len(data_loader),
                        batch_time,
                        bplotter=False,
                        bavg=True,
                        bsummary=False,
                    )

        #save validation loss; 'top1' assumes topk included 1 in create().
        self.vallosses = self.logger_val.info['loss']['loss'].avg
        acc = self.logger_val.info['metrics']['top1'].avg

        self.logger_val.logger(
            epoch,
            epoch,
            i,
            len(data_loader),
            batch_time,
            bplotter=True,
            bavg=True,
            bsummary=True,
        )

        print('Confusion Matriz')
        print(self.cnf.value(), flush=True)
        print('\n')

        self.visheatmap.show('Confusion Matriz', self.cnf.value())
        return acc

    #def _to_end_epoch(self, epoch, epochs, train_loader, val_loader):
    #print('>> Reset', flush=True )
    #w = 1-self.cnf.value().diagonal()
    #train_loader.dataset.reset(w)

    def test(self, data_loader):
        """Collect softmax predictions and true labels over a labeled loader.

        Returns (Yhat, Y): per-sample class probabilities and label indices.
        Buffers are pre-sized to len(loader) * batch_size and trimmed to the
        number of samples actually seen.
        """

        n = len(data_loader) * data_loader.batch_size
        Yhat = np.zeros((n, self.num_output_channels))
        Y = np.zeros((n, 1))
        k = 0

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, sample in enumerate(tqdm(data_loader)):

                # get data (image, label)
                inputs = sample['image']
                targets = pytutils.argmax(sample['label'])

                x = inputs.cuda() if self.cuda else inputs
                x = Variable(x, requires_grad=False, volatile=True)

                # fit (forward)
                yhat = self.net(x)
                yhat = F.softmax(yhat, dim=1)
                yhat = pytutils.to_np(yhat)

                for j in range(yhat.shape[0]):
                    Y[k] = targets[j]
                    Yhat[k, :] = yhat[j]
                    k += 1

                #print( 'Test:', i , flush=True )

        # trim the pre-allocated buffers to the filled rows
        Yhat = Yhat[:k, :]
        Y = Y[:k]

        return Yhat, Y

    def predict(self, data_loader):
        """Collect softmax predictions over an unlabeled (id, image) loader.

        Returns (Ids, Yhat).
        """

        n = len(data_loader) * data_loader.batch_size
        Yhat = np.zeros((n, self.num_output_channels))
        Ids = np.zeros((n, 1))
        k = 0

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, (Id, inputs) in enumerate(tqdm(data_loader)):

                # get data (image, label)
                #inputs = sample['image']
                #Id = sample['id']

                x = inputs.cuda() if self.cuda else inputs
                x = Variable(x, requires_grad=False, volatile=True)

                # fit (forward)
                yhat = self.net(x)
                yhat = F.softmax(yhat, dim=1)
                yhat = pytutils.to_np(yhat)

                for j in range(yhat.shape[0]):
                    Yhat[k, :] = yhat[j]
                    Ids[k] = Id[j]
                    k += 1

        Yhat = Yhat[:k, :]
        Ids = Ids[:k]

        return Ids, Yhat

    def __call__(self, image):
        """Return softmax scores for one batch as a numpy array.

        NOTE(review): nn.Softmax() is built without a dim argument, which is
        deprecated and warns on modern torch — confirm the intended axis.
        """

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            x = image.cuda() if self.cuda else image
            x = Variable(x, requires_grad=False, volatile=True)
            msoft = nn.Softmax()
            yhat = msoft(self.net(x))
            yhat = pytutils.to_np(yhat)

        return yhat

    def representation(self, data_loader):
        """"
        Representation: extract embedded features for every sample.
            -data_loader: simple data loader for image
        Returns (embX, embY): embeddings of size self.net.dim and labels.
        """

        # switch to evaluate mode
        self.net.eval()

        n = len(data_loader) * data_loader.batch_size
        k = 0

        # embebed features (pre-sized, trimmed after the loop)
        embX = np.zeros([n, self.net.dim])
        embY = np.zeros([n, 1])

        batch_time = AverageMeter()
        end = time.time()
        for i, sample in enumerate(data_loader):

            # get data (image, label)
            inputs, targets = sample['image'], pytutils.argmax(sample['label'])
            inputs_var = pytutils.to_var(inputs, self.cuda, False, True)

            # representation
            emb = self.net.representation(inputs_var)
            emb = pytutils.to_np(emb)
            for j in range(emb.shape[0]):
                embX[k, :] = emb[j, :]
                embY[k] = targets[j]
                k += 1

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            print(
                'Representation: |{:06d}/{:06d}||{batch_time.val:.3f} ({batch_time.avg:.3f})|'
                .format(i, len(data_loader), batch_time=batch_time))

        embX = embX[:k, :]
        embY = embY[:k]

        return embX, embY

    def _create_model(self, arch, num_output_channels, num_input_channels,
                      pretrained):
        """
        Create model
            @arch (string): select architecture (key into nnmodels registry)
            @num_output_channels (int)
            @num_input_channels (int)
            @pretrained (bool)
        """

        self.net = None
        self.size_input = 0

        kw = {
            'num_classes': num_output_channels,
            'num_channels': num_input_channels,
            'pretrained': pretrained
        }
        self.net = nnmodels.__dict__[arch](**kw)

        self.s_arch = arch
        self.size_input = self.net.size_input
        self.num_output_channels = num_output_channels
        self.num_input_channels = num_input_channels

        if self.cuda == True:
            self.net.cuda()
        if self.parallel == True and self.cuda == True:
            self.net = nn.DataParallel(self.net,
                                       device_ids=range(
                                           torch.cuda.device_count()))

    def _create_loss(self, loss):
        """Instantiate ``self.criterion`` from the loss name.

        Supported: 'cross', 'mse', 'l1'. Any other name aborts via assert.
        NOTE(review): size_average is deprecated on modern torch
        (reduction='mean'); .cuda() is called unconditionally — requires GPU.
        """

        # create loss
        if loss == 'cross':
            self.criterion = nn.CrossEntropyLoss().cuda()
        elif loss == 'mse':
            self.criterion = nn.MSELoss(size_average=True).cuda()
        elif loss == 'l1':
            self.criterion = nn.L1Loss(size_average=True).cuda()
        else:
            assert (False)

        self.s_loss = loss
class SegmentationNeuralNet(NeuralNetAbstract):
    """
    Segmentation Neural Net: trains a pixel-wise classifier with weighted
    losses and periodically visualizes predictions via Visdom.

    NOTE(review): written for legacy PyTorch (Variable(..., volatile=True),
    loss.data[0]); confirm the target torch version before reuse.
    """
    def __init__(self,
                 patchproject,
                 nameproject,
                 no_cuda=True,
                 parallel=False,
                 seed=1,
                 print_freq=10,
                 gpu=0,
                 view_freq=1):
        """
        Initialization
            -patchproject (str): path project
            -nameproject (str):  name project
            -no_cuda (bool): system cuda (default is True)
            -parallel (bool)
            -seed (int)
            -print_freq (int)
            -gpu (int)
            -view_freq (int): visualize predictions every N epochs
        """

        super(SegmentationNeuralNet,
              self).__init__(patchproject, nameproject, no_cuda, parallel,
                             seed, print_freq, gpu)
        # epochs between Visdom visualizations in evaluate()
        self.view_freq = view_freq

    def create(
        self,
        arch,
        num_output_channels,
        num_input_channels,
        loss,
        lr,
        momentum,
        optimizer,
        lrsch,
        pretrained=False,
        size_input=388,
    ):
        """
        Create the model, metrics, loggers and Visdom panels.
            -arch (string): architecture
            -num_output_channels (int): number of segmentation classes
            -num_input_channels (int): number of input image channels
            -loss (string): loss name (see ``_create_loss``)
            -lr (float): learning rate
            -momentum: SGD momentum
            -optimizer (string): optimizer name
            -lrsch (string): scheduler learning rate
            -pretrained (bool)
            -size_input (int): expected input image size
        """
        super(SegmentationNeuralNet,
              self).create(arch, num_output_channels, num_input_channels, loss,
                           lr, momentum, optimizer, lrsch, pretrained)
        self.size_input = size_input

        self.accuracy = nloss.Accuracy()
        self.dice = nloss.Dice()

        # Set the graphic visualization
        self.logger_train = Logger('Train', ['loss'], ['accs', 'dices'],
                                   self.plotter)
        self.logger_val = Logger('Val  ', ['loss'], ['accs', 'dices'],
                                 self.plotter)

        self.visheatmap = gph.HeatMapVisdom(env_name=self.nameproject,
                                            heatsize=(100, 100))
        self.visimshow = gph.ImageVisdom(env_name=self.nameproject,
                                         imsize=(100, 100))

    def training(self, data_loader, epoch=0):
        """Run one training epoch; logs loss, accuracy and dice."""

        #reset logger
        self.logger_train.reset()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, sample in enumerate(data_loader):

            # measure data loading time
            data_time.update(time.time() - end)
            # get data (image, label, weight); weight is the per-pixel loss map
            inputs, targets, weights = sample['image'], sample[
                'label'], sample['weight']
            batch_size = inputs.size(0)

            if self.cuda:
                targets = targets.cuda(non_blocking=True)
                inputs_var = Variable(inputs.cuda(), requires_grad=False)
                targets_var = Variable(targets.cuda(), requires_grad=False)
                weights_var = Variable(weights.cuda(), requires_grad=False)
            else:
                inputs_var = Variable(inputs, requires_grad=False)
                targets_var = Variable(targets, requires_grad=False)
                weights_var = Variable(weights, requires_grad=False)

            # fit (forward)
            outputs = self.net(inputs_var)

            # measure accuracy and record loss
            loss = self.criterion(outputs, targets_var, weights_var)
            accs = self.accuracy(outputs, targets_var)
            dices = self.dice(outputs, targets_var)

            # optimizer; gradient is scaled by batch_size before the step
            # (NOTE(review): intent unclear — confirm this scaling is wanted)
            self.optimizer.zero_grad()
            (loss * batch_size).backward()
            self.optimizer.step()

            # update (loss.data[0] is the legacy scalar accessor)
            self.logger_train.update(
                {'loss': loss.data[0]},
                {
                    'accs': accs,
                    'dices': dices
                },
                batch_size,
            )

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                self.logger_train.logger(
                    epoch,
                    epoch + float(i + 1) / len(data_loader),
                    i,
                    len(data_loader),
                    batch_time,
                )

    def evaluate(self, data_loader, epoch=0):
        """Evaluate on ``data_loader``; returns mean accuracy and, every
        ``view_freq`` epochs, visualizes the LAST batch's predictions."""

        # reset loader
        self.logger_val.reset()
        batch_time = AverageMeter()

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, sample in enumerate(data_loader):

                # get data (image, label)
                inputs, targets, weights = sample['image'], sample[
                    'label'], sample['weight']
                batch_size = inputs.size(0)

                # NOTE(review): volatile=True was removed in torch 0.4; the
                # surrounding torch.no_grad() already disables autograd.
                if self.cuda:
                    targets = targets.cuda(non_blocking=True)
                    inputs_var = Variable(inputs.cuda(),
                                          requires_grad=False,
                                          volatile=True)
                    targets_var = Variable(targets.cuda(),
                                           requires_grad=False,
                                           volatile=True)
                    weights_var = Variable(weights.cuda(),
                                           requires_grad=False,
                                           volatile=True)
                else:
                    inputs_var = Variable(inputs,
                                          requires_grad=False,
                                          volatile=True)
                    targets_var = Variable(targets,
                                           requires_grad=False,
                                           volatile=True)
                    weights_var = Variable(weights,
                                           requires_grad=False,
                                           volatile=True)

                # fit (forward)
                outputs = self.net(inputs_var)

                # measure accuracy and record loss
                loss = self.criterion(outputs, targets_var, weights_var)
                accs = self.accuracy(outputs, targets_var)
                dices = self.dice(outputs, targets_var)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # update
                self.logger_val.update(
                    {'loss': loss.data[0]},
                    {
                        'accs': accs,
                        'dices': dices
                    },
                    batch_size,
                )

                if i % self.print_freq == 0:
                    self.logger_val.logger(
                        epoch,
                        epoch,
                        i,
                        len(data_loader),
                        batch_time,
                        bplotter=False,
                        bavg=True,
                        bsummary=False,
                    )

        #save validation loss
        self.vallosses = self.logger_val.info['loss']['loss'].avg
        acc = self.logger_val.info['metrics']['accs'].avg

        self.logger_val.logger(
            epoch,
            epoch,
            i,
            len(data_loader),
            batch_time,
            bplotter=True,
            bavg=True,
            bsummary=True,
        )

        #vizual_freq: outputs/targets_var/weights_var/inputs_var here still
        # hold the last batch of the loop above.
        if epoch % self.view_freq == 0:

            prob = F.softmax(outputs, dim=1)
            prob = prob.data[0]
            _, maxprob = torch.max(prob, 0)

            self.visheatmap.show('Label',
                                 targets_var.data.cpu()[0].numpy()[1, :, :])
            self.visheatmap.show('Weight map',
                                 weights_var.data.cpu()[0].numpy()[0, :, :])
            self.visheatmap.show('Image',
                                 inputs_var.data.cpu()[0].numpy()[0, :, :])
            self.visheatmap.show('Max prob',
                                 maxprob.cpu().numpy().astype(np.float32))
            for k in range(prob.shape[0]):
                self.visheatmap.show('Heat map {}'.format(k),
                                     prob.cpu()[k].numpy())

        return acc

    def test(self, data_loader):
        """Predict softmax mask probabilities for every batch.

        Returns (ids, masks): per-batch id arrays (meta column 0) and
        per-batch probability maps.
        """

        masks = []
        ids = []
        k = 0

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, sample in enumerate(tqdm(data_loader)):

                # get data (image, label)
                inputs, meta = sample['image'], sample['metadata']
                idd = meta[:, 0]
                x = inputs.cuda() if self.cuda else inputs

                # fit (forward)
                yhat = self.net(x)
                yhat = F.softmax(yhat, dim=1)
                yhat = pytutils.to_np(yhat)

                masks.append(yhat)
                ids.append(idd)

        return ids, masks

    def __call__(self, image):
        """Segment a single batch; returns an (H, W, C) probability map for
        the first image (numpy)."""

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            x = image.cuda() if self.cuda else image
            yhat = F.softmax(self.net(x), dim=1)
            # NCHW -> HWCN, then keep only the first sample
            yhat = pytutils.to_np(yhat).transpose(2, 3, 1, 0)[..., 0]

        return yhat

    def _create_model(self, arch, num_output_channels, num_input_channels,
                      pretrained):
        """
        Create model
            -arch (string): select architecture (key into nnmodels registry)
            -num_output_channels (int)
            -num_input_channels (int)
            -pretrained (bool)
        """

        self.net = None

        #--------------------------------------------------------------------------------------------
        # select architecture
        #--------------------------------------------------------------------------------------------
        #kw = {'num_classes': num_output_channels, 'num_channels': num_input_channels, 'pretrained': pretrained}

        # NOTE: segmentation models take 'in_channels', unlike the
        # classification models above which take 'num_channels'.
        kw = {
            'num_classes': num_output_channels,
            'in_channels': num_input_channels,
            'pretrained': pretrained
        }
        self.net = nnmodels.__dict__[arch](**kw)

        self.s_arch = arch
        self.num_output_channels = num_output_channels
        self.num_input_channels = num_input_channels

        if self.cuda == True:
            self.net.cuda()
        if self.parallel == True and self.cuda == True:
            self.net = nn.DataParallel(self.net,
                                       device_ids=range(
                                           torch.cuda.device_count()))

    def _create_loss(self, loss):
        """Instantiate ``self.criterion`` from the loss name.

        Supported: 'wmce', 'bdice', 'wbdice', 'wmcedice', 'wfocalmce',
        'mcedice' (all project-defined weighted segmentation losses).
        Any other name aborts via assert.
        """

        # create loss
        if loss == 'wmce':
            self.criterion = nloss.WeightedMCEloss()
        elif loss == 'bdice':
            self.criterion = nloss.BDiceLoss()
        elif loss == 'wbdice':
            self.criterion = nloss.WeightedBDiceLoss()
        elif loss == 'wmcedice':
            self.criterion = nloss.WeightedMCEDiceLoss()
        elif loss == 'wfocalmce':
            self.criterion = nloss.WeightedMCEFocalloss()
        elif loss == 'mcedice':
            self.criterion = nloss.MCEDiceLoss()
        else:
            assert (False)

        self.s_loss = loss
# Exemple #12  (separator scraped from a code-example site; commented out so
# the module remains importable — "0" was a vote count, not code)
class AttentionGMMNeuralNet(AttentionNeuralNetAbstract):
    """
    Attention Neural Net and GMM representation 
    Args:
        -patchproject (str): path project
        -nameproject (str):  name project
        -no_cuda (bool): system cuda (default is True)
        -parallel (bool)
        -seed (int)
        -print_freq (int)
        -gpu (int)
        -view_freq (in epochs)
    """
    def __init__(self,
                 patchproject,
                 nameproject,
                 no_cuda=True,
                 parallel=False,
                 seed=1,
                 print_freq=10,
                 gpu=0,
                 view_freq=1):
        """Initialize by delegating all configuration to the base class."""
        super(AttentionGMMNeuralNet,
              self).__init__(patchproject, nameproject, no_cuda, parallel,
                             seed, print_freq, gpu, view_freq)
        

 
    def create(self,
               arch,
               num_output_channels,
               num_input_channels,
               loss,
               lr,
               optimizer,
               lrsch,
               momentum=0.9,
               weight_decay=5e-4,
               pretrained=False,
               size_input=388,
               num_classes=8,
               ):
        """
        Configure the network, loss, optimizer, scheduler and loggers.

        Args:
            arch (str): architecture name.
            num_output_channels (int): number of output channels.
            num_input_channels (int): number of input channels.
            loss (str): loss name.
            lr (float): learning rate.
            optimizer (str): optimizer name.
            lrsch (str): learning-rate scheduler name.
            momentum (float): optimizer momentum.
            weight_decay (float): L2 regularization.
            pretrained (bool): load pretrained weights.
            size_input (int): expected input size.
            num_classes (int): number of classes.
        """
        super(AttentionGMMNeuralNet, self).create(
            arch,
            num_output_channels,
            num_input_channels,
            loss,
            lr,
            optimizer,
            lrsch,
            momentum,
            weight_decay,
            pretrained,
            size_input,
            num_classes,
        )

        # Track the combined objective together with its three components.
        self.logger_train = Logger('Train',
                                   ['loss', 'loss_gmm', 'loss_bce', 'loss_att'],
                                   ['topk', 'gmm'], self.plotter)
        self.logger_val = Logger('Val  ',
                                 ['loss', 'loss_gmm', 'loss_bce', 'loss_att'],
                                 ['topk', 'gmm'], self.plotter)
        
      
    def training(self, data_loader, epoch=0):
        """
        Run one training epoch over ``data_loader``.

        Each batch is (x_org, x_img, y_mask, meta); meta[:, 0] holds the
        class label and meta[:, 1:] a flattened 2x3 affine matrix (theta).
        The objective is loss_bce + loss_gmm + loss_att.
        """

        #reset logger
        self.logger_train.reset()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, (x_org, x_img, y_mask, meta ) in enumerate(data_loader):

            # measure data loading time
            data_time.update(time.time() - end)
            batch_size = x_img.shape[0]

            y_lab = meta[:,0]
            # NOTE(review): y_theta is computed and moved to GPU but never
            # used in this loop -- confirm whether it is still needed.
            y_theta   = meta[:,1:].view(-1, 2, 3)

            if self.cuda:
                x_org   = x_org.cuda()
                x_img   = x_img.cuda()
                y_mask  = y_mask.cuda()
                y_lab   = y_lab.cuda()
                y_theta = y_theta.cuda()

            # fit (forward); second input is the image masked by channel 1
            # of y_mask
            z, y_lab_hat, att, _, _ = self.net( x_img, x_img*y_mask[:,1,...].unsqueeze(dim=1) )

            # measure accuracy and record loss
            loss_bce  = self.criterion_bce(  y_lab_hat, y_lab.long() )
            loss_gmm  = self.criterion_gmm(  z, y_lab )
            loss_att  = self.criterion_att(  x_org, y_mask, att )
            loss      = loss_bce + loss_gmm + loss_att
            topk      = self.topk( y_lab_hat, y_lab.long() )
            gmm       = self.gmm( z, y_lab )

            # optimizer
            self.optimizer.zero_grad()
            (loss).backward() #batch_size
            self.optimizer.step()

            # update running averages in the logger
            self.logger_train.update(
                {'loss': loss.cpu().item(), 'loss_gmm': loss_gmm.cpu().item(), 'loss_bce': loss_bce.cpu().item(), 'loss_att':loss_att.cpu().item() },
                {'topk': topk[0][0].cpu(), 'gmm': gmm.cpu().item() },
                batch_size,
                )

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                self.logger_train.logger( epoch, epoch + float(i+1)/len(data_loader), i, len(data_loader), batch_time,   )


    def evaluate(self, data_loader, epoch=0):
        """
        Run one validation epoch; returns the average top-k accuracy.

        Every ``self.view_freq`` epochs the attention/feature maps of the
        *last* batch are pushed to the visdom heatmap visualizer.
        """

        # reset loader
        self.logger_val.reset()
        batch_time = AverageMeter()

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, (x_org, x_img, y_mask, meta) in enumerate(data_loader):

                # get data (image, label)
                batch_size = x_img.shape[0]

                y_lab = meta[:,0]
                # NOTE(review): y_theta is computed but never used here.
                y_theta   = meta[:,1:].view(-1, 2, 3)

                if self.cuda:
                    x_org   = x_org.cuda()
                    x_img   = x_img.cuda()
                    y_mask  = y_mask.cuda()
                    y_lab   = y_lab.cuda()
                    y_theta = y_theta.cuda()

                # fit (forward); validation uses the single-input forward
                # path (training passes a second, masked input)
                z, y_lab_hat, att, fmap, srf  = self.net( x_img )

                # measure accuracy and record loss
                loss_bce  = self.criterion_bce( y_lab_hat, y_lab.long() )
                loss_gmm  = self.criterion_gmm( z, y_lab )
                loss_att  = self.criterion_att( x_org, y_mask, att  )
                loss      = loss_bce + loss_gmm + loss_att
                topk      = self.topk( y_lab_hat, y_lab.long() )
                gmm       = self.gmm( z, y_lab )

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # update running averages in the logger
                self.logger_val.update(
                    {'loss': loss.cpu().item(), 'loss_gmm': loss_gmm.cpu().item(), 'loss_bce': loss_bce.cpu().item(), 'loss_att':loss_att.cpu().item() },
                    {'topk': topk[0][0].cpu(), 'gmm': gmm.cpu().item() },
                    batch_size,
                    )

                if i % self.print_freq == 0:
                    self.logger_val.logger(
                        epoch, epoch, i,len(data_loader),
                        batch_time,
                        bplotter=False,
                        bavg=True,
                        bsummary=False,
                        )

        #save validation loss
        self.vallosses = self.logger_val.info['loss']['loss'].avg
        acc = self.logger_val.info['metrics']['topk'].avg

        # final epoch summary (uses the last loop index i)
        self.logger_val.logger(
            epoch, epoch, i, len(data_loader),
            batch_time,
            bplotter=True,
            bavg=True,
            bsummary=True,
            )

        #vizual_freq: visualize maps from the last batch of the loader
        if epoch % self.view_freq == 0:

            att   = att[0,:,:,:].permute( 1,2,0 ).mean(dim=2)
            srf   = srf[0,:,:,:].permute( 1,2,0 ).sum(dim=2)
            fmap  = fmap[0,:,:,:].permute( 1,2,0 )

            self.visheatmap.show('Image', x_img.data.cpu()[0].numpy()[0,:,:])
            self.visheatmap.show('Image Attention',att.cpu().numpy().astype(np.float32) )
            self.visheatmap.show('Feature Map',srf.cpu().numpy().astype(np.float32) )
            self.visheatmap.show('Attention Map',fmap.cpu().numpy().astype(np.float32) )

        return acc

    
    def representation(self, dataloader, breal=True):
        """
        Collect labels, predictions and embeddings for a whole dataset.

        Args:
            dataloader: yields dict samples ('image', 'label') when ``breal``
                is True, otherwise (x_org, x_img, y_mask, y_lab) tuples.
            breal (bool): True for the real-dataset sample format.

        Returns:
            tuple: (labels, predicted logits, embeddings z) as numpy arrays.
        """
        labels, predictions, embeddings = [], [], []
        self.net.eval()
        with torch.no_grad():
            for sample in tqdm(dataloader):
                if breal:
                    x_img, y_lab = sample['image'], sample['label']
                    y_lab = y_lab.argmax(dim=1)
                else:
                    x_org, x_img, y_mask, y_lab = sample
                    y_lab = y_lab[:, 0]

                if self.cuda:
                    x_img = x_img.cuda()

                z, y_lab_hat, _, _, _ = self.net(x_img)
                labels.append(y_lab)
                predictions.append(y_lab_hat.data.cpu())
                embeddings.append(z.data.cpu())

        return (np.concatenate(labels, axis=0),
                np.concatenate(predictions, axis=0),
                np.concatenate(embeddings, axis=0))
    
    
    def __call__(self, image):
        """Forward one batch; returns (z, softmax probs, att, fmap, srf)."""
        self.net.eval()
        with torch.no_grad():
            batch = image.cuda() if self.cuda else image
            z, y_lab_hat, att, fmap, srf = self.net(batch)
            probs = F.softmax(y_lab_hat, dim=1)
        return z, probs, att, fmap, srf


    def _create_loss(self, loss):
        """
        Select the criteria for the attention model.

        Args:
            loss (str): only 'attloss' is supported; it installs the
                cross-entropy, attention and deep-GMM criteria.

        Raises:
            ValueError: for any other loss name.
        """
        if loss == 'attloss':
            self.criterion_bce = nn.CrossEntropyLoss()
            if self.cuda:
                # Only move to GPU when CUDA is enabled; the original called
                # .cuda() unconditionally and crashed on CPU-only hosts.
                self.criterion_bce = self.criterion_bce.cuda()
            self.criterion_att = nloss.Attloss()
            self.criterion_gmm = nloss.DGMMLoss( self.num_classes, cuda=self.cuda )
        else:
            # `assert(False)` is stripped under -O; raise explicitly.
            raise ValueError('Unknown loss: {}'.format(loss))

        self.s_loss = loss
    def create(self,
               arch,
               num_output_channels,
               num_input_channels,
               loss,
               lr,
               optimizer,
               lrsch,
               momentum=0.9,
               weight_decay=5e-4,
               pretrained=False,
               size_input=388,
               cascade_type='none'):
        """
        Create / configure the model, loss, optimizer and loggers.

        NOTE(review): this method looks like a misplaced duplicate of
        ``SegmentationNeuralNet.create`` -- it is defined inside
        ``AttentionGMMNeuralNet`` here (overriding the ``create`` above),
        yet calls ``super(SegmentationNeuralNet, self)`` and references
        segmentation-only helpers such as ``self.default_step``. Confirm
        which class it belongs to before relying on it.

        Args:
            -arch (string): architecture
            -num_output_channels,
            -num_input_channels,
            -loss (string):
            -lr (float): learning rate
            -optimizer (string) :
            -lrsch (string): scheduler learning rate
            -pretrained (bool)

        """

        cfg_opt = {'momentum': momentum, 'weight_decay': weight_decay}
        cfg_scheduler = {'step_size': 100, 'gamma': 0.1}

        # NOTE(review): inside AttentionGMMNeuralNet this super() call will
        # fail at runtime (wrong class passed to super()).
        super(SegmentationNeuralNet, self).create(arch,
                                                  num_output_channels,
                                                  num_input_channels,
                                                  loss,
                                                  lr,
                                                  optimizer,
                                                  lrsch,
                                                  pretrained,
                                                  cfg_opt=cfg_opt,
                                                  cfg_scheduler=cfg_scheduler)
        self.size_input = size_input
        self.num_output_channels = num_output_channels
        self.cascade_type = cascade_type
        self.segs_per_forward = 7

        # choose the forward/loss step strategy by cascade type
        if self.cascade_type == 'none':
            self.step = self.default_step
        elif self.cascade_type == 'ransac':
            self.step = self.ransac_step
        elif self.cascade_type == 'ransac2':
            self.step = self.ransac_step2
        elif self.cascade_type == 'simple':
            self.step = self.cascate_step
        else:
            # NOTE(review): raising a string is a TypeError in Python 3;
            # should raise an Exception subclass.
            raise "Cascada not found"

        self.accuracy = nloss.Accuracy()
        # NOTE(review): dice_dim stays unbound for channel counts other than
        # 2 or 4, which would raise UnboundLocalError below.
        if num_output_channels == 2:
            dice_dim = (1, )
        if num_output_channels == 4:
            dice_dim = (1, 2, 3)

        self.dice = nloss.Dice(dice_dim)

        # Set the graphic visualization
        self.logger_train = Logger('Train', ['loss'], ['accs', 'dices'],
                                   self.plotter)
        self.logger_val = Logger('Val  ', ['loss'], ['accs', 'dices', 'PQ'],
                                 self.plotter)

        self.visheatmap = gph.HeatMapVisdom(env_name=self.nameproject,
                                            heatsize=(256, 256))
        self.visimshow = gph.ImageVisdom(env_name=self.nameproject,
                                         imsize=(256, 256))
        if self.half_precision:
            self.scaler = torch.cuda.amp.GradScaler()
class SegmentationNeuralNet(NeuralNetAbstract):
    """
    Segmentation Neural Net 
    """
    def __init__(self,
                 patchproject,
                 nameproject,
                 no_cuda=True,
                 parallel=False,
                 seed=1,
                 print_freq=10,
                 gpu=0,
                 view_freq=1,
                 half_precision=False):
        """
        Set up the segmentation trainer.

        Args:
            patchproject (str): path of the project.
            nameproject (str): name of the project.
            no_cuda (bool): disable CUDA (default True).
            parallel (bool): use DataParallel.
            seed (int): RNG seed.
            print_freq (int): logging frequency in iterations.
            gpu (int): GPU index.
            view_freq (int): visualization frequency in epochs.
            half_precision (bool): enable torch.cuda.amp mixed precision.
        """
        super(SegmentationNeuralNet, self).__init__(patchproject, nameproject,
                                                    no_cuda, parallel, seed,
                                                    print_freq, gpu)
        self.view_freq = view_freq
        self.half_precision = half_precision

    def create(self,
               arch,
               num_output_channels,
               num_input_channels,
               loss,
               lr,
               optimizer,
               lrsch,
               momentum=0.9,
               weight_decay=5e-4,
               pretrained=False,
               size_input=388,
               cascade_type='none'):
        """
        Create / configure the segmentation model, loss and loggers.

        Args:
            arch (str): architecture name.
            num_output_channels (int): output channels (2 or 4 supported
                by the Dice metric).
            num_input_channels (int): input channels.
            loss (str): loss name (see ``_create_loss``).
            lr (float): learning rate.
            optimizer (str): optimizer name.
            lrsch (str): learning-rate scheduler name.
            momentum (float): optimizer momentum.
            weight_decay (float): L2 regularization.
            pretrained (bool): load pretrained weights.
            size_input (int): expected input size.
            cascade_type (str): 'none', 'ransac', 'ransac2' or 'simple'.

        Raises:
            ValueError: for an unknown ``cascade_type`` or an unsupported
                ``num_output_channels``.
        """
        cfg_opt = {'momentum': momentum, 'weight_decay': weight_decay}
        cfg_scheduler = {'step_size': 100, 'gamma': 0.1}

        super(SegmentationNeuralNet, self).create(arch,
                                                  num_output_channels,
                                                  num_input_channels,
                                                  loss,
                                                  lr,
                                                  optimizer,
                                                  lrsch,
                                                  pretrained,
                                                  cfg_opt=cfg_opt,
                                                  cfg_scheduler=cfg_scheduler)
        self.size_input = size_input
        self.num_output_channels = num_output_channels
        self.cascade_type = cascade_type
        self.segs_per_forward = 7

        # Dispatch table for the forward/loss step strategy.
        steps = {
            'none': self.default_step,
            'ransac': self.ransac_step,
            'ransac2': self.ransac_step2,
            'simple': self.cascate_step,
        }
        if cascade_type not in steps:
            # The original `raise "Cascada not found"` is itself a TypeError
            # in Python 3 (exceptions must derive from BaseException).
            raise ValueError('Cascade type not found: {}'.format(cascade_type))
        self.step = steps[cascade_type]

        self.accuracy = nloss.Accuracy()
        if num_output_channels == 2:
            dice_dim = (1, )
        elif num_output_channels == 4:
            dice_dim = (1, 2, 3)
        else:
            # dice_dim was previously left unbound for other channel counts,
            # causing an UnboundLocalError below; fail with a clear message.
            raise ValueError(
                'Unsupported num_output_channels for Dice: {}'.format(
                    num_output_channels))

        self.dice = nloss.Dice(dice_dim)

        # Set the graphic visualization
        self.logger_train = Logger('Train', ['loss'], ['accs', 'dices'],
                                   self.plotter)
        self.logger_val = Logger('Val  ', ['loss'], ['accs', 'dices', 'PQ'],
                                 self.plotter)

        self.visheatmap = gph.HeatMapVisdom(env_name=self.nameproject,
                                            heatsize=(256, 256))
        self.visimshow = gph.ImageVisdom(env_name=self.nameproject,
                                         imsize=(256, 256))
        if self.half_precision:
            self.scaler = torch.cuda.amp.GradScaler()

    def ransac_step(self,
                    inputs,
                    targets,
                    max_deep=4,
                    segs_per_forward=20,
                    src_c=3,
                    verbose=False):
        """
        RANSAC-style cascaded step: repeatedly fuse random subsets of
        candidate segmentations through the network.

        Args:
            inputs: tensor whose first ``src_c`` channels are the source
                image and the remaining channels candidate segmentations.
            targets: ground-truth segmentation.
            max_deep (int): number of cascade levels.
            segs_per_forward (int): unused; ``self.segs_per_forward`` is
                used instead (kept for signature compatibility).
            src_c (int): number of source-image channels.
            verbose (bool): print per-step diagnostics.

        Returns:
            tuple: (accumulated loss, last network output).
        """
        srcs = inputs[:, :src_c]
        segs = inputs[:, src_c:]

        final_loss = 0.0
        for lv in range(max_deep):
            n_segs = segs.shape[1]
            new_segs = []
            # number of candidates consumed at this level
            actual_c = self.segs_per_forward**(max_deep - lv)
            if verbose: print(segs.shape, actual_c)
            actual_seg_ids = np.random.choice(range(n_segs), size=actual_c)
            step_segs = segs[:, actual_seg_ids]

            for idx in range(0, actual_c, self.segs_per_forward):
                mini_inp = torch.cat(
                    (srcs, step_segs[:, idx:idx + self.segs_per_forward]),
                    dim=1)
                mini_out = self.net(mini_inp)
                # calculate loss; bind the per-step value so the verbose
                # branch can report it (the original referenced an undefined
                # `actual_loss` and raised NameError when verbose=True).
                actual_loss = self.criterion(mini_out, targets)
                final_loss += actual_loss
                new_segs.append(mini_out.argmax(1, keepdim=True))

                if verbose:
                    print(mini_inp.shape, idx, idx + self.segs_per_forward,
                          actual_loss.item())

            # the fused predictions become the next level's candidates
            segs = torch.cat(new_segs, dim=1)

        return final_loss, mini_out

    def ransac_step2(self,
                     inputs,
                     targets,
                     max_deep=4,
                     n_times=10,
                     segs_per_forward=20,
                     src_c=3,
                     verbose=False):
        """
        RANSAC-style step, variant 2: ``n_times`` rounds, each fusing a
        random subset of the (growing) segmentation pool.

        Args:
            inputs: tensor whose first ``src_c`` channels are the source
                image and the remaining channels candidate segmentations.
            targets: ground-truth segmentation.
            max_deep (int): unused; kept for signature compatibility.
            n_times (int): number of fusion rounds.
            segs_per_forward (int): unused; ``self.segs_per_forward`` is
                used instead (kept for signature compatibility).
            src_c (int): number of source-image channels.
            verbose (bool): unused; kept for signature compatibility.

        Returns:
            tuple: (accumulated loss, last network output).
        """
        srcs = inputs[:, :src_c]
        segs = inputs[:, src_c:]

        final_loss = 0.0
        for _ in range(n_times):
            n_segs = segs.shape[1]

            actual_seg_ids = np.random.choice(range(n_segs),
                                              size=self.segs_per_forward)
            step_segs = segs[:, actual_seg_ids]

            mini_inp = torch.cat((srcs, step_segs), dim=1)
            mini_out = self.net(mini_inp)
            # accumulate the fusion loss for this round
            final_loss += self.criterion(mini_out, targets)

            # append the new fused prediction to the candidate pool
            segs = torch.cat((segs, mini_out.argmax(1, keepdim=True)), dim=1)

        return final_loss, mini_out

    def cascate_step(self,
                     inputs,
                     targets,
                     segs_per_forward=20,
                     src_c=3,
                     verbose=False):
        """
        Simple cascade: consume the candidate pool ``self.segs_per_forward``
        segmentations at a time, feeding each fused output back into the
        pool until a single segmentation remains.

        Args:
            inputs: tensor whose first ``src_c`` channels are the source
                image and the remaining channels candidate segmentations.
            targets: ground-truth segmentation.
            segs_per_forward (int): unused; ``self.segs_per_forward`` is
                used instead (kept for signature compatibility).
            src_c (int): number of source-image channels.
            verbose (bool): print per-step diagnostics.

        Returns:
            tuple: (accumulated loss, last network output).
        """
        srcs = inputs[:, :src_c]
        segs = inputs[:, src_c:]
        lv_segs = segs.clone()

        final_loss = 0.0
        n_segs = lv_segs.shape[1]
        # random shuffle of the candidate order (without replacement)
        actual_seg_ids = np.random.choice(range(n_segs),
                                          size=n_segs,
                                          replace=False)
        lv_segs = lv_segs[:, actual_seg_ids]

        while n_segs > 1:

            if verbose: print(n_segs)

            inputs_seg = lv_segs[:, :self.segs_per_forward]
            # sample with replacement when fewer candidates remain than needed
            inputs_seg_ids = np.random.choice(
                range(inputs_seg.shape[1]),
                size=self.segs_per_forward,
                replace=inputs_seg.shape[1] < self.segs_per_forward)
            inputs_seg = inputs_seg[:, inputs_seg_ids]

            mini_inp = torch.cat((srcs, inputs_seg), dim=1)
            mini_out = self.net(mini_inp)
            # calculate loss; bind the step value so the verbose branch can
            # report it (the original referenced an undefined `actual_loss`
            # and raised NameError when verbose=True).
            actual_loss = self.criterion(mini_out, targets)
            final_loss += actual_loss

            if verbose:
                print(mini_inp.shape, self.segs_per_forward,
                      actual_loss.item())
            # drop the consumed candidates and append the fused prediction
            lv_segs = torch.cat((lv_segs[:, self.segs_per_forward:],
                                 mini_out.argmax(1, keepdim=True)),
                                dim=1)
            n_segs = lv_segs.shape[1]
        return final_loss, mini_out

    def default_step(self, inputs, targets):
        """Single forward pass: return (criterion loss, raw network output)."""
        predictions = self.net(inputs)
        return self.criterion(predictions, targets), predictions

    def training(self, data_loader, epoch=0):
        """
        Run one training epoch.

        Samples are dicts with 'image', 'label' and optionally 'weight'.
        Uses ``self.step`` (selected in ``create``) for the forward/loss
        pass; supports mixed precision when ``self.half_precision`` is set.
        """
        #reset logger
        self.logger_train.reset()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, sample in enumerate(data_loader):

            # measure data loading time
            data_time.update(time.time() - end)
            # get data (image, label, weight)
            inputs, targets = sample['image'], sample['label']
            weights = None
            if 'weight' in sample.keys():
                weights = sample['weight']

            batch_size = inputs.shape[0]

            if self.cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()
                if type(weights) is not type(None):
                    weights = weights.cuda()

            # fit (forward)
            if self.half_precision:
                # NOTE(review): backward/scaler.step run *inside* the
                # autocast block; the commonly documented pattern keeps them
                # outside autocast -- confirm this is intentional.
                with torch.cuda.amp.autocast():

                    loss, outputs = self.step(inputs, targets)

                    self.optimizer.zero_grad()
                    # loss is scaled by batch_size, matching the fp32 branch
                    self.scaler.scale(loss * batch_size).backward()
                    self.scaler.step(self.optimizer)
                    self.scaler.update()

            else:
                loss, outputs = self.step(inputs, targets)

                self.optimizer.zero_grad()
                (batch_size * loss).backward()  #batch_size
                self.optimizer.step()

            accs = self.accuracy(outputs, targets)
            dices = self.dice(outputs, targets)
            #pq    = metrics.pq_metric(outputs.cpu().detach().numpy(), targets.cpu().detach().numpy())
            #pq, n_cells  = metrics.pq_metric(targets, outputs)

            # update running averages in the logger
            self.logger_train.update(
                {'loss': loss.item()},
                {
                    'accs': accs.item(),
                    #'PQ': pq,
                    'dices': dices.item()
                },
                batch_size,
            )

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                self.logger_train.logger(
                    epoch,
                    epoch + float(i + 1) / len(data_loader),
                    i,
                    len(data_loader),
                    batch_time,
                )

    def evaluate(self, data_loader, epoch=0):
        """
        Run one validation epoch; returns the cell-weighted PQ score.

        PQ (panoptic quality) is skipped at epoch 0 (reported as 0) and
        otherwise accumulated weighted by the number of cells per batch.
        Every ``self.view_freq`` epochs the predictions of the *last* batch
        are pushed to the visdom heatmap visualizer.
        """
        pq_sum = 0
        total_cells = 0
        self.logger_val.reset()
        batch_time = AverageMeter()

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, sample in enumerate(data_loader):

                # get data (image, label)
                inputs, targets = sample['image'], sample['label']

                weights = None
                if 'weight' in sample.keys():
                    weights = sample['weight']
                #inputs, targets = sample['image'], sample['label']

                batch_size = inputs.shape[0]

                #print(inputs.shape)

                if self.cuda:
                    inputs = inputs.cuda()
                    targets = targets.cuda()
                    if type(weights) is not type(None):
                        weights = weights.cuda()
                #print(inputs.shape)

                # fit (forward)
                if self.half_precision:
                    with torch.cuda.amp.autocast():
                        loss, outputs = self.step(inputs, targets)
                else:
                    loss, outputs = self.step(inputs, targets)

                # measure accuracy and record loss

                accs = self.accuracy(outputs, targets)
                dices = self.dice(outputs, targets)

                #targets_np = targets[0][1].cpu().numpy().astype(int)
                # PQ computation is expensive: skip it on the first epoch
                if epoch == 0:
                    pq = 0
                    n_cells = 1
                else:
                    #pq, n_cells    = metrics.pq_metric(targets, outputs)
                    if False:  #self.skip_background:
                        out_shape = outputs.shape
                        zeros = torch.zeros((out_shape[0], 1, out_shape[2],
                                             out_shape[3])).cuda()
                        outputs = torch.cat([zeros, outputs], 1)

                    all_metrics, n_cells, _ = metrics.get_metrics(
                        targets, outputs)
                    pq = all_metrics['pq']

                # accumulate PQ weighted by the number of cells in the batch
                pq_sum += pq * n_cells
                total_cells += n_cells

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # update running averages in the logger
                #print(loss.item(), accs, dices, batch_size)
                self.logger_val.update(
                    {'loss': loss.item()},
                    {
                        'accs': accs.item(),
                        'PQ': (pq_sum / total_cells) if total_cells > 0 else 0,
                        'dices': dices.item()
                    },
                    batch_size,
                )

                if i % self.print_freq == 0:
                    self.logger_val.logger(
                        epoch,
                        epoch,
                        i,
                        len(data_loader),
                        batch_time,
                        bplotter=False,
                        bavg=True,
                        bsummary=False,
                    )

        #save validation loss
        if total_cells == 0:
            pq_weight = 0
        else:
            pq_weight = pq_sum / total_cells

        print(f"PQ: {pq_weight:0.4f}, {pq_sum:0.4f}, {total_cells}")

        self.vallosses = self.logger_val.info['loss']['loss'].avg
        acc = self.logger_val.info['metrics']['accs'].avg
        #pq = pq_weight

        # final epoch summary (uses the last loop index i)
        self.logger_val.logger(
            epoch,
            epoch,
            i,
            len(data_loader),
            batch_time,
            bplotter=True,
            bavg=True,
            bsummary=True,
        )

        #vizual_freq: visualize predictions from the last batch
        if epoch % self.view_freq == 0:

            prob = F.softmax(outputs.cpu().float(), dim=1)
            prob = prob.data[0]
            maxprob = torch.argmax(prob, 0)

            self.visheatmap.show('Label',
                                 targets.data.cpu()[0].numpy()[1, :, :])
            #self.visheatmap.show('Weight map', weights.data.cpu()[0].numpy()[0,:,:])
            self.visheatmap.show('Image',
                                 inputs.data.cpu()[0].numpy()[0, :, :])
            self.visheatmap.show('Max prob',
                                 maxprob.cpu().numpy().astype(np.float32))
            for k in range(prob.shape[0]):
                self.visheatmap.show('Heat map {}'.format(k),
                                     prob.cpu()[k].numpy())

        print(f"End Val: wPQ{pq_weight}")
        return pq_weight

    def test(self, data_loader):
        """
        Run inference over a test loader.

        Args:
            data_loader: yields dict samples with 'image' and 'metadata';
                metadata[:, 0] is the sample id.

        Returns:
            tuple: (ids, masks) -- lists of id tensors and softmax
            probability maps (numpy arrays), one entry per batch.
        """
        masks = []
        ids = []

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            # (removed unused locals `k` and `end` from the original)
            for i, sample in enumerate(tqdm(data_loader)):

                # get data (image, metadata)
                inputs, meta = sample['image'], sample['metadata']
                idd = meta[:, 0]
                x = inputs.cuda() if self.cuda else inputs

                # fit (forward)
                yhat = self.net(x)
                yhat = F.softmax(yhat, dim=1)
                yhat = pytutils.to_np(yhat)

                masks.append(yhat)
                ids.append(idd)

        return ids, masks

    def __call__(self, image):
        """Evaluate one batch and return the softmax probability maps."""
        self.net.eval()
        with torch.no_grad():
            batch = image.cuda() if self.cuda else image
            logits = self.net(batch)
            probs = F.softmax(logits, dim=1)
        return probs

    def _create_model(self, arch, num_output_channels, num_input_channels,
                      pretrained):
        """
        Instantiate the backbone network.

        Args:
            arch (str): architecture name; must be a callable exported by
                ``nnmodels`` (looked up via ``nnmodels.__dict__``).
            num_output_channels (int): number of output classes/channels.
            num_input_channels (int): number of input channels.
            pretrained (bool): load pretrained weights if supported.

        Side effects:
            Sets ``self.net``, ``self.s_arch``, ``self.num_output_channels``
            and ``self.num_input_channels``; moves the model to GPU and wraps
            it in ``nn.DataParallel`` when CUDA/parallel mode is enabled.
        """
        self.net = None

        #--------------------------------------------------------------------------------------------
        # select architecture
        #--------------------------------------------------------------------------------------------
        kw = {
            'num_classes': num_output_channels,
            'in_channels': num_input_channels,
            'pretrained': pretrained
        }
        self.net = nnmodels.__dict__[arch](**kw)

        self.s_arch = arch
        self.num_output_channels = num_output_channels
        self.num_input_channels = num_input_channels

        # Use truthiness instead of `== True` comparisons (PEP 8).
        if self.cuda:
            self.net.cuda()
        if self.parallel and self.cuda:
            self.net = nn.DataParallel(self.net,
                                       device_ids=range(
                                           torch.cuda.device_count()))

    def _create_loss(self, loss):
        """
        Select the segmentation criterion by name.

        Args:
            loss (str): loss identifier (see the mapping below). 'jreg'
                builds a class-pair lambda matrix that depends on
                ``self.num_output_channels`` (2 or 4 supported).

        Raises:
            ValueError: for an unknown loss name or, with 'jreg', an
                unsupported number of output channels.
        """
        # Status notes ("Pass"/"Fail") carried over from the original.
        simple_losses = {
            'wmce': nloss.WeightedMCEloss,  # Not tested
            'bdice': nloss.BDiceLoss,  # Fail
            'wbdice': nloss.WeightedBDiceLoss,  # Fail
            'wmcedice': nloss.WeightedMCEDiceLoss,  # Fail
            'mcedice': nloss.MCEDiceLoss,  # Fail
            'bce': nloss.BCELoss,  # Pass
            'bce2c': nloss.BCELoss2c,  # Pass
            'mce': nloss.MCELoss,  # Pass
            'wbce': nloss.WeightedBCELoss,
            'wce': nloss.WeightedCrossEntropyLoss,  # Pass
            'wfocalce': nloss.WeightedCEFocalLoss,  # Pass
            'focaldice': nloss.FocalDiceLoss,  # Pass
            'wfocaldice': nloss.WeightedFocalDiceLoss,  # Pass
            'dice': nloss.DiceLoss,  # FAIL
            'msedice': nloss.MSEDICELoss,  # FAIL
            'mcc': nloss.MCCLoss,  # FAIL
            'mdice': nloss.MDiceLoss,  # FAIL
            'wcefd': nloss.WeightedCEFocalDice,
        }

        if loss in simple_losses:
            self.criterion = simple_losses[loss]()
        elif loss == 'jreg':
            self.criterion = nloss.WCE_J_SIMPL(
                lambda_dict=self._jreg_lambda_dict())
        else:
            # The original used `assert (False)`, which is stripped under
            # `python -O` and carries no message; raise explicitly instead.
            raise ValueError('Unknown loss: {}'.format(loss))

        self.s_loss = loss

    def _jreg_lambda_dict(self):
        """
        Build the class-pair weight matrix for the 'jreg' loss: '1' on the
        diagonal, '0.5' elsewhere (string-valued, matching what the original
        literal dictionaries passed to ``nloss.WCE_J_SIMPL``).
        """
        n = self.num_output_channels
        if n not in (2, 4):
            # lambda_dict was previously left unbound for other sizes,
            # raising UnboundLocalError; fail with a clear message instead.
            raise ValueError(
                'jreg loss supports 2 or 4 output channels, got {}'.format(n))
        return {
            str(i): {str(j): '1' if i == j else '0.5'
                     for j in range(n)}
            for i in range(n)
        }