Example #1
    def training(self, data_loader, epoch=0):

        self.logger_train.reset()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, sample in enumerate(data_loader):

            # measure data loading time
            data_time.update(time.time() - end)
            # get data (image, label)
            inputs, targets = sample['image'], pytutils.argmax(sample['label'])
            batch_size = inputs.size(0)

            if self.cuda:
                inputs_var = inputs.cuda()
                targets = targets.cuda(non_blocking=True)
                targets_var = targets
            else:
                inputs_var, targets_var = inputs, targets

            # fit (forward)
            outputs = self.net(inputs_var)

            # measure accuracy and record loss
            loss = self.criterion(outputs, targets_var)
            pred = self.accuracy(outputs.data, targets)

            # optimizer
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # update
            self.logger_train.update(
                {'loss': loss.item()},
                dict(
                    zip(self.metrics_name,
                        [pred[p].item() for p in range(len(self.metrics_name))])),
                batch_size,
            )

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                self.logger_train.logger(
                    epoch,
                    epoch + float(i + 1) / len(data_loader),
                    i,
                    len(data_loader),
                    batch_time,
                )
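
Every snippet on this page leans on an AverageMeter helper for timing and running statistics, but its definition is not shown here. A minimal sketch, assuming the conventional val/avg/sum/count interface these loops rely on:

class AverageMeter:
    """Tracks the most recent value and a running average (assumed interface)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0    # most recent value
        self.avg = 0.0    # running weighted average
        self.sum = 0.0    # weighted sum of all values
        self.count = 0    # total weight seen so far

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count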
Example #2
    def training(self, data_loader, epoch=0):        

        # reset logger
        self.logger_train.reset()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, (x_org, x_img, y_mask, meta) in enumerate(data_loader):
            
            # measure data loading time
            data_time.update(time.time() - end)
            batch_size = x_img.shape[0]
            
            y_lab = meta[:, 0]
            y_theta = meta[:, 1:].view(-1, 2, 3)

            if self.cuda:
                x_org = x_org.cuda()
                x_img = x_img.cuda()
                y_mask = y_mask.cuda()
                y_lab = y_lab.cuda()
                y_theta = y_theta.cuda()
            
            # fit (forward)
            y_lab_hat, att, theta, _, _, _ = self.net(
                x_img, x_img * y_mask[:, 1, ...].unsqueeze(dim=1))
            
            # measure accuracy and record loss
            loss_bce = self.criterion_bce(y_lab_hat, y_lab.long())
            loss_att = self.criterion_att(x_org, y_mask, att)
            loss_stn = self.criterion_stn(x_org, y_theta, theta)

            loss = loss_bce + loss_att + loss_stn
            topk = self.topk(y_lab_hat, y_lab.long())
                     
            
            # optimizer
            self.optimizer.zero_grad()
            (loss * batch_size).backward()
            self.optimizer.step()
            
            # update
            self.logger_train.update(
                {'loss': loss.cpu().item(),
                 'loss_bce': loss_bce.cpu().item(),
                 'loss_att': loss_att.cpu().item(),
                 'loss_stn': loss_stn.cpu().item()},
                {'topk': topk[0][0].cpu()},
                batch_size,
            )
            
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            
            if i % self.print_freq == 0:
                self.logger_train.logger(
                    epoch,
                    epoch + float(i + 1) / len(data_loader),
                    i,
                    len(data_loader),
                    batch_time,
                )
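
Note the (loss * batch_size).backward() call above: with a criterion that averages over the batch, scaling by the batch size before backpropagation yields the gradient of the per-batch sum rather than the mean, so a short final batch contributes in proportion to its size. A small self-contained check of that equivalence, assuming a mean-reduction loss:

import torch

x = torch.randn(4, 3, requires_grad=True)
y = torch.randint(0, 3, (4,))
mean_ce = torch.nn.CrossEntropyLoss(reduction='mean')
sum_ce = torch.nn.CrossEntropyLoss(reduction='sum')

# gradient of the mean loss scaled by batch size ...
g_scaled = torch.autograd.grad(mean_ce(x, y) * x.shape[0], x)[0]
# ... equals the gradient of the summed loss
g_summed = torch.autograd.grad(sum_ce(x, y), x)[0]
assert torch.allclose(g_scaled, g_summed)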
    def training(self, data_loader, epoch=0):

        self.logger_train.reset()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, sample in enumerate(data_loader):

            # measure data loading time
            data_time.update(time.time() - end)
            # get data (image, label)
            x, y = sample['image'], sample['label'].argmax(1).long()
            batch_size = x.size(0)

            if self.cuda:
                x = x.cuda()
                y = y.cuda()

            # fit (forward)
            outputs = self.net(x)

            # measure accuracy and record loss
            loss = self.criterion(outputs, y.long())
            pred = self.accuracy(outputs.data, y)

            # optimizer
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # update
            self.logger_train.update(
                {'loss': loss.item()},
                dict(
                    zip(self.metrics_name, [
                        pred[p].item() for p in range(len(self.metrics_name))
                    ])),
                batch_size,
            )

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                self.logger_train.logger(
                    epoch,
                    epoch + float(i + 1) / len(data_loader),
                    i,
                    len(data_loader),
                    batch_time,
                )
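
self.accuracy here presumably maps to the standard top-k precision helper (consistent with the metrics_name entries such as top1 that the evaluate methods below read back from the logger). A sketch of that conventional implementation, not taken from this codebase:

import torch

def accuracy(output, target, topk=(1,)):
    """Precision@k for each k in topk (assumed helper)."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()                                    # shape: (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res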
    def training(self, data_loader, epoch=0):

        self.logger_train.reset()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, (iD, image, prob) in enumerate(data_loader):

            # measure data loading time
            data_time.update(time.time() - end)
            x, y = image, prob
            batch_size = x.size(0)

            if self.cuda:
                x = x.cuda()
                y = y.cuda()

            # fit (forward)
            yhat = self.net(x)

            # measure accuracy and record loss
            loss = self.criterion(yhat, y.float())
            pred = self.accuracy(yhat.data, y.data)
            f1 = self.f_score(yhat.data, y.data)

            # optimizer
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # update
            self.logger_train.update(
                {'loss': loss.item()},
                {
                    'acc': pred.item(),
                    'f1': f1.item()
                },
                batch_size,
            )

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                self.logger_train.logger(
                    epoch,
                    epoch + float(i + 1) / len(data_loader),
                    i,
                    len(data_loader),
                    batch_time,
                )
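
f_score is not defined on this page; given the BCE-style criterion on a float prob target, it is plausibly a thresholded multi-label F1. A minimal sketch under that assumption:

import torch

def f_score(logits, targets, threshold=0.5, eps=1e-8):
    """Mean per-sample F1 over a multi-label batch (assumed metric)."""
    preds = (torch.sigmoid(logits) > threshold).float()
    tp = (preds * targets).sum(dim=1)                  # true positives
    precision = tp / (preds.sum(dim=1) + eps)
    recall = tp / (targets.sum(dim=1) + eps)
    return (2 * precision * recall / (precision + recall + eps)).mean()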
    def representation(self, data_loader):
        """"
        Representation
            -data_loader: simple data loader for image
        """

        # switch to evaluate mode
        self.net.eval()

        n = len(data_loader) * data_loader.batch_size
        k = 0

        # embedded features
        embX = np.zeros([n, self.net.dim])
        embY = np.zeros([n, 1])

        batch_time = AverageMeter()
        end = time.time()
        for i, sample in enumerate(data_loader):

            # get data (image, label)
            x, y = sample['image'], sample['label'].argmax(1).long()
            x = x.cuda() if self.cuda else x

            # representation
            emb = self.net.representation(x)
            emb = pytutils.to_np(emb)
            for j in range(emb.shape[0]):
                embX[k, :] = emb[j, :]
                embY[k] = y[j]
                k += 1

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            print(
                'Representation: |{:06d}/{:06d}||{batch_time.val:.3f} ({batch_time.avg:.3f})|'
                .format(i, len(data_loader), batch_time=batch_time))

        embX = embX[:k, :]
        embY = embY[:k]

        return embX, embY
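
Because representation returns plain NumPy arrays, the embeddings drop straight into scikit-learn. For instance, a 2-D t-SNE projection of the feature space (hypothetical usage; trainer names an instance of the class above):

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

embX, embY = trainer.representation(data_loader)

proj = TSNE(n_components=2).fit_transform(embX)     # project features to 2-D
plt.scatter(proj[:, 0], proj[:, 1], c=embY.ravel(), s=5, cmap='tab10')
plt.title('t-SNE of learned representations')
plt.show()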
    def evaluate(self, data_loader, epoch=0):

        self.logger_val.reset()
        self.cnf.reset()
        batch_time = AverageMeter()

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, sample in enumerate(data_loader):

                # get data (image, label)
                x, y = sample['image'], sample['label'].argmax(1).long()
                batch_size = x.size(0)

                if self.cuda:
                    x = x.cuda()
                    y = y.cuda()

                # fit (forward)
                outputs = self.net(x)

                # measure accuracy and record loss
                loss = self.criterion(outputs, y)
                pred = self.accuracy(outputs.data, y.data)
                self.cnf.add(outputs.argmax(1), y)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # update
                self.logger_val.update(
                    {'loss': loss.item()},
                    dict(
                        zip(self.metrics_name, [
                            pred[p].item()
                            for p in range(len(self.metrics_name))
                        ])),
                    batch_size,
                )

                if i % self.print_freq == 0:
                    self.logger_val.logger(
                        epoch,
                        epoch,
                        i,
                        len(data_loader),
                        batch_time,
                        bplotter=False,
                        bavg=True,
                        bsummary=False,
                    )

        # save validation loss
        self.vallosses = self.logger_val.info['loss']['loss'].avg
        acc = self.logger_val.info['metrics']['top1'].avg

        self.logger_val.logger(
            epoch,
            epoch,
            i,
            len(data_loader),
            batch_time,
            bplotter=True,
            bavg=True,
            bsummary=True,
        )

        print('Confusion Matrix')
        print(self.cnf.value(), flush=True)
        print('\n')

        self.visheatmap.show('Confusion Matrix', self.cnf.value())
        return acc
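
self.cnf behaves like a torchnet-style ConfusionMeter: add() accumulates prediction/target index pairs and value() returns the K x K matrix printed above. A minimal sketch of that assumed interface:

import torch

class ConfusionMeter:
    """Accumulates a K x K confusion matrix (assumed interface)."""

    def __init__(self, k):
        self.k = k
        self.reset()

    def reset(self):
        self.conf = torch.zeros(self.k, self.k, dtype=torch.long)

    def add(self, predicted, target):
        # rows index the true class, columns the predicted class
        for p, t in zip(predicted.view(-1).tolist(), target.view(-1).tolist()):
            self.conf[t, p] += 1

    def value(self):
        return self.conf.numpy()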
Example #7
    def evaluate(self, data_loader, epoch=0, *args):
        """
        evaluate on validation dataset
        :param data_loader: which data_loader to use
        :param epoch: current epoch
        :return:
            acc: average accuracy on data_loader
        """
        # reset logger
        self.logger_val.reset()
        batch_time = AverageMeter()

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, sample in enumerate(data_loader):

                # get data (image, label)
                if self.breal == 'real':
                    x_img, y_lab = sample["image"], sample["label"]
                    y_lab = y_lab.argmax(dim=1)

                    if self.cuda:
                        x_img = x_img.cuda()
                        y_lab = y_lab.cuda()
                    x_org = x_img.clone().detach()
                else:
                    x_org, x_img, y_mask, meta = sample

                    y_lab = meta[:, 0]
                    y_theta = meta[:, 1:].view(-1, 2, 3)

                    if self.cuda:
                        x_org = x_org.cuda()
                        x_img = x_img.cuda()
                        y_mask = y_mask.cuda()
                        y_lab = y_lab.cuda()
                        y_theta = y_theta.cuda()

                # fit (forward)
                # print("x_img size", x_img.size())
                y_lab_hat = self.net(x_img)

                # measure accuracy and record loss
                loss_bce = self.criterion_bce(y_lab_hat, y_lab.long())
                loss = loss_bce
                topk = self.topk(y_lab_hat, y_lab.long())

                batch_size = x_img.shape[0]

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # update
                self.logger_val.update(
                    {
                        'loss': loss.cpu().item(),
                        'loss_bce': loss_bce.cpu().item()
                    },
                    {'topk': topk[0][0].cpu()},
                    batch_size,
                )

                # print intermediate results at the configured print frequency
                if i % self.print_freq == 0:
                    self.logger_val.logger(
                        epoch,
                        epoch,
                        i,
                        len(data_loader),
                        batch_time,
                        bplotter=False,
                        bavg=True,
                        bsummary=False,
                    )

        # save validation loss and accuracy
        self.vallosses = self.logger_val.info['loss']['loss'].avg
        acc = self.logger_val.info['metrics']['topk'].avg

        # print the average loss and accuracy after the full validation pass
        self.logger_val.logger(
            epoch,
            epoch,
            i,
            len(data_loader),
            batch_time,
            bplotter=True,
            bavg=True,
            bsummary=True,
        )

        return acc
Example #8
    def training(self, data_loader, epoch=0, *args):
        # reset logger
        self.logger_train.reset()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, sample in enumerate(data_loader):

            # measure data loading time
            data_time.update(time.time() - end)
            # if dataset is real
            if self.breal == 'real':
                x_img, y_lab = sample['image'], sample['label']
                y_lab = y_lab.argmax(dim=1)

                if self.cuda:
                    x_img = x_img.cuda()
                    y_lab = y_lab.cuda()

                x_org = x_img.clone().detach()

            else:
                # if dataset is synthetic
                x_org, x_img, y_mask, meta = sample

                y_lab = meta[:, 0]
                y_theta = meta[:, 1:].view(-1, 2, 3)

                if self.cuda:
                    x_org = x_org.cuda()
                    x_img = x_img.cuda()
                    y_mask = y_mask.cuda()
                    y_lab = y_lab.cuda()
                    y_theta = y_theta.cuda()

            # fit (forward)
            y_lab_hat = self.net(x_img)

            # calculate classification loss
            loss_bce = self.criterion_bce(y_lab_hat, y_lab.long())
            loss = loss_bce
            # accuracy of choosing top k predicted classes
            topk = self.topk(y_lab_hat, y_lab.long())

            batch_size = x_img.shape[0]

            # optimizer
            self.optimizer.zero_grad()
            (loss * batch_size).backward()
            self.optimizer.step()

            # update
            self.logger_train.update(
                {
                    'loss': loss.cpu().item(),
                    'loss_bce': loss_bce.cpu().item()
                },
                {'topk': topk[0][0].cpu()},
                batch_size,
            )

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                self.logger_train.logger(
                    epoch,
                    epoch + float(i + 1) / len(data_loader),
                    i,
                    len(data_loader),
                    batch_time,
                )
    def evaluate(self, data_loader, epoch=0):

        self.logger_val.reset()
        #self.cnf.reset()
        batch_time = AverageMeter()

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, (iD, image, prob) in enumerate(data_loader):

                # get data (image, label)
                x, y = image, prob  #.argmax(1).long()
                batch_size = x.size(0)

                if self.cuda:
                    x = x.cuda()
                    y = y.cuda()

                # fit (forward)
                yhat = self.net(x)

                # measure accuracy and record loss
                loss = self.criterion(yhat, y.float())
                pred = self.accuracy(yhat.data, y.data)
                f1 = self.f_score(yhat.data, y.data)

                #self.cnf.add( outputs.argmax(1), y )

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # update
                self.logger_val.update(
                    {'loss': loss.item()},
                    {
                        'acc': pred.item(),
                        'f1': f1.item()
                    },
                    batch_size,
                )

                if i % self.print_freq == 0:
                    self.logger_val.logger(
                        epoch,
                        epoch,
                        i,
                        len(data_loader),
                        batch_time,
                        bplotter=False,
                        bavg=True,
                        bsummary=False,
                    )

        # save validation loss
        self.vallosses = self.logger_val.info['loss']['loss'].avg
        acc = self.logger_val.info['metrics']['acc'].avg

        self.logger_val.logger(
            epoch,
            epoch,
            i,
            len(data_loader),
            batch_time,
            bplotter=True,
            bavg=True,
            bsummary=True,
        )

        #print('Confusion Matriz')
        #print(self.cnf.value(), flush=True)
        #print('\n')
        #self.visheatmap.show('Confusion Matriz', self.cnf.value())

        return acc
Example #10
    def evaluate(self, data_loader, epoch=0):

        self.logger_val.reset()
        self.cnf.reset()
        batch_time = AverageMeter()

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, sample in enumerate(data_loader):

                # get data (image, label)
                inputs, targets = sample['image'], pytutils.argmax(
                    sample['label'])
                batch_size = inputs.size(0)

                if self.cuda:
                    targets = targets.cuda(non_blocking=True)
                    inputs_var = inputs.cuda()
                    targets_var = targets
                else:
                    inputs_var = inputs
                    targets_var = targets

                # fit (forward)
                outputs = self.net(inputs_var)

                # measure accuracy and record loss
                loss = self.criterion(outputs, targets_var)
                pred = self.accuracy(outputs.data, targets)
                self.cnf.add(outputs.argmax(1), targets_var)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # update
                self.logger_val.update(
                    {'loss': loss.item()},
                    dict(
                        zip(self.metrics_name, [
                            pred[p].item() for p in range(len(self.metrics_name))
                        ])),
                    batch_size,
                )

                if i % self.print_freq == 0:
                    self.logger_val.logger(
                        epoch,
                        epoch,
                        i,
                        len(data_loader),
                        batch_time,
                        bplotter=False,
                        bavg=True,
                        bsummary=False,
                    )

        # save validation loss
        self.vallosses = self.logger_val.info['loss']['loss'].avg
        acc = self.logger_val.info['metrics']['top1'].avg

        self.logger_val.logger(
            epoch,
            epoch,
            i,
            len(data_loader),
            batch_time,
            bplotter=True,
            bavg=True,
            bsummary=True,
        )

        print('Confusion Matrix')
        print(self.cnf.value(), flush=True)
        print('\n')

        self.visheatmap.show('Confusion Matrix', self.cnf.value())
        return acc
    def training(self, data_loader, epoch=0):

        # reset logger
        self.logger_train.reset()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, sample in enumerate(data_loader):

            # measure data loading time
            data_time.update(time.time() - end)
            # get data (image, label, weight)
            inputs = sample['image']
            targets = sample['label']
            weights = sample['weight']
            batch_size = inputs.size(0)

            if self.cuda:
                targets = targets.cuda(non_blocking=True)
                inputs_var = inputs.cuda()
                targets_var = targets
                weights_var = weights.cuda()
            else:
                inputs_var, targets_var, weights_var = inputs, targets, weights

            # fit (forward)
            outputs = self.net(inputs_var)

            # measure accuracy and record loss
            loss = self.criterion(outputs, targets_var, weights_var)
            accs = self.accuracy(outputs, targets_var)
            dices = self.dice(outputs, targets_var)

            # optimizer
            self.optimizer.zero_grad()
            (loss * batch_size).backward()
            self.optimizer.step()

            # update
            self.logger_train.update(
                {'loss': loss.item()},
                {
                    'accs': accs,
                    'dices': dices
                },
                batch_size,
            )

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                self.logger_train.logger(
                    epoch,
                    epoch + float(i + 1) / len(data_loader),
                    i,
                    len(data_loader),
                    batch_time,
                )
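
self.dice is presumably a soft Dice coefficient computed from the softmax output against one-hot targets, the usual overlap score for segmentation. A hedged sketch, not this repository's actual metric:

import torch.nn.functional as F

def dice(outputs, targets, eps=1e-6):
    """Mean soft Dice over classes in [0, 1]; targets one-hot (B, C, H, W)."""
    probs = F.softmax(outputs, dim=1)
    dims = (0, 2, 3)                                   # reduce batch and space
    intersection = (probs * targets).sum(dims)
    cardinality = probs.sum(dims) + targets.sum(dims)
    return (2.0 * intersection / (cardinality + eps)).mean()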
    def evaluate(self, data_loader, epoch=0):

        # reset logger
        self.logger_val.reset()
        batch_time = AverageMeter()

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, sample in enumerate(data_loader):

                # get data (image, label, weight)
                inputs = sample['image']
                targets = sample['label']
                weights = sample['weight']
                batch_size = inputs.size(0)

                if self.cuda:
                    targets = targets.cuda(non_blocking=True)
                    inputs_var = inputs.cuda()
                    targets_var = targets
                    weights_var = weights.cuda()
                else:
                    inputs_var, targets_var, weights_var = inputs, targets, weights

                # fit (forward)
                outputs = self.net(inputs_var)

                # measure accuracy and record loss
                loss = self.criterion(outputs, targets_var, weights_var)
                accs = self.accuracy(outputs, targets_var)
                dices = self.dice(outputs, targets_var)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # update
                self.logger_val.update(
                    {'loss': loss.item()},
                    {
                        'accs': accs,
                        'dices': dices
                    },
                    batch_size,
                )

                if i % self.print_freq == 0:
                    self.logger_val.logger(
                        epoch,
                        epoch,
                        i,
                        len(data_loader),
                        batch_time,
                        bplotter=False,
                        bavg=True,
                        bsummary=False,
                    )

        # save validation loss
        self.vallosses = self.logger_val.info['loss']['loss'].avg
        acc = self.logger_val.info['metrics']['accs'].avg

        self.logger_val.logger(
            epoch,
            epoch,
            i,
            len(data_loader),
            batch_time,
            bplotter=True,
            bavg=True,
            bsummary=True,
        )

        # visualization frequency
        if epoch % self.view_freq == 0:

            prob = F.softmax(outputs, dim=1)
            prob = prob.data[0]
            _, maxprob = torch.max(prob, 0)

            self.visheatmap.show('Label',
                                 targets_var.data.cpu()[0].numpy()[1, :, :])
            self.visheatmap.show('Weight map',
                                 weights_var.data.cpu()[0].numpy()[0, :, :])
            self.visheatmap.show('Image',
                                 inputs_var.data.cpu()[0].numpy()[0, :, :])
            self.visheatmap.show('Max prob',
                                 maxprob.cpu().numpy().astype(np.float32))
            for k in range(prob.shape[0]):
                self.visheatmap.show('Heat map {}'.format(k),
                                     prob.cpu()[k].numpy())

        return acc
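
The three-argument criterion(outputs, targets, weights) suggests a pixel-wise weighted cross-entropy in the spirit of the U-Net weight maps, which would also explain the 'Weight map' heatmap above. A sketch of such a loss under that assumption:

import torch.nn.functional as F

def weighted_cross_entropy(outputs, targets, weights):
    """Pixel-wise CE scaled by a weight map (assumed loss form).

    outputs: (B, C, H, W) logits; targets: (B, C, H, W) one-hot;
    weights: (B, 1, H, W) per-pixel weights.
    """
    log_probs = F.log_softmax(outputs, dim=1)
    ce = -(targets * log_probs).sum(dim=1, keepdim=True)   # (B, 1, H, W)
    return (weights * ce).mean()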
Example #13
    def evaluate(self, data_loader, epoch=0):
        
        # reset logger
        self.logger_val.reset()
        batch_time = AverageMeter()

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, (x_org, x_img, y_mask, meta) in enumerate(data_loader):

                # get data (image, label)
                batch_size = x_img.shape[0]

                y_lab = meta[:, 0]
                y_theta = meta[:, 1:].view(-1, 2, 3)
                                
                if self.cuda:
                    x_org = x_org.cuda()
                    x_img = x_img.cuda()
                    y_mask = y_mask.cuda()
                    y_lab = y_lab.cuda()
                    y_theta = y_theta.cuda()
                
                # fit (forward)
                z, y_lab_hat, att, fmap, srf = self.net(x_img)

                # measure accuracy and record loss
                loss_bce = self.criterion_bce(y_lab_hat, y_lab.long())
                loss_gmm = self.criterion_gmm(z, y_lab)
                loss_att = self.criterion_att(x_org, y_mask, att)
                loss = loss_bce + loss_gmm + loss_att
                topk = self.topk(y_lab_hat, y_lab.long())
                gmm = self.gmm(z, y_lab)
                                
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                                
                # update
                self.logger_val.update(
                    {'loss': loss.cpu().item(),
                     'loss_gmm': loss_gmm.cpu().item(),
                     'loss_bce': loss_bce.cpu().item(),
                     'loss_att': loss_att.cpu().item()},
                    {'topk': topk[0][0].cpu(), 'gmm': gmm.cpu().item()},
                    batch_size,
                )

                if i % self.print_freq == 0:
                    self.logger_val.logger(
                        epoch,
                        epoch,
                        i,
                        len(data_loader),
                        batch_time,
                        bplotter=False,
                        bavg=True,
                        bsummary=False,
                    )

        # save validation loss
        self.vallosses = self.logger_val.info['loss']['loss'].avg
        acc = self.logger_val.info['metrics']['topk'].avg

        self.logger_val.logger(
            epoch,
            epoch,
            i,
            len(data_loader),
            batch_time,
            bplotter=True,
            bavg=True,
            bsummary=True,
        )

        # visualization frequency
        if epoch % self.view_freq == 0:

            att = att[0, :, :, :].permute(1, 2, 0).mean(dim=2)
            srf = srf[0, :, :, :].permute(1, 2, 0).sum(dim=2)
            fmap = fmap[0, :, :, :].permute(1, 2, 0)

            self.visheatmap.show('Image', x_img.data.cpu()[0].numpy()[0, :, :])
            self.visheatmap.show('Image Attention',
                                 att.cpu().numpy().astype(np.float32))
            self.visheatmap.show('Feature Map',
                                 srf.cpu().numpy().astype(np.float32))
            self.visheatmap.show('Attention Map',
                                 fmap.cpu().numpy().astype(np.float32))
            
        return acc
    def evaluate(self, data_loader, epoch=0):
        pq_sum = 0
        total_cells = 0
        self.logger_val.reset()
        batch_time = AverageMeter()

        # switch to evaluate mode
        self.net.eval()
        with torch.no_grad():
            end = time.time()
            for i, sample in enumerate(data_loader):

                # get data (image, label)
                inputs, targets = sample['image'], sample['label']

                weights = None
                if 'weight' in sample.keys():
                    weights = sample['weight']
                #inputs, targets = sample['image'], sample['label']

                batch_size = inputs.shape[0]

                #print(inputs.shape)

                if self.cuda:
                    inputs = inputs.cuda()
                    targets = targets.cuda()
                    if weights is not None:
                        weights = weights.cuda()
                #print(inputs.shape)

                # fit (forward)
                if self.half_precision:
                    with torch.cuda.amp.autocast():
                        loss, outputs = self.step(inputs, targets)
                else:
                    loss, outputs = self.step(inputs, targets)

                # measure accuracy and record loss

                accs = self.accuracy(outputs, targets)
                dices = self.dice(outputs, targets)

                #targets_np = targets[0][1].cpu().numpy().astype(int)
                if epoch == 0:
                    pq = 0
                    n_cells = 1
                else:
                    #pq, n_cells    = metrics.pq_metric(targets, outputs)
                    if False:  #self.skip_background:
                        out_shape = outputs.shape
                        zeros = torch.zeros((out_shape[0], 1, out_shape[2],
                                             out_shape[3])).cuda()
                        outputs = torch.cat([zeros, outputs], 1)

                    all_metrics, n_cells, _ = metrics.get_metrics(
                        targets, outputs)
                    pq = all_metrics['pq']

                pq_sum += pq * n_cells
                total_cells += n_cells

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # update
                #print(loss.item(), accs, dices, batch_size)
                self.logger_val.update(
                    {'loss': loss.item()},
                    {
                        'accs': accs.item(),
                        'PQ': (pq_sum / total_cells) if total_cells > 0 else 0,
                        'dices': dices.item()
                    },
                    batch_size,
                )

                if i % self.print_freq == 0:
                    self.logger_val.logger(
                        epoch,
                        epoch,
                        i,
                        len(data_loader),
                        batch_time,
                        bplotter=False,
                        bavg=True,
                        bsummary=False,
                    )

        # save validation loss
        if total_cells == 0:
            pq_weight = 0
        else:
            pq_weight = pq_sum / total_cells

        print(f"PQ: {pq_weight:0.4f}, {pq_sum:0.4f}, {total_cells}")

        self.vallosses = self.logger_val.info['loss']['loss'].avg
        acc = self.logger_val.info['metrics']['accs'].avg
        #pq = pq_weight

        self.logger_val.logger(
            epoch,
            epoch,
            i,
            len(data_loader),
            batch_time,
            bplotter=True,
            bavg=True,
            bsummary=True,
        )

        # visualization frequency
        if epoch % self.view_freq == 0:

            prob = F.softmax(outputs.cpu().float(), dim=1)
            prob = prob.data[0]
            maxprob = torch.argmax(prob, 0)

            self.visheatmap.show('Label',
                                 targets.data.cpu()[0].numpy()[1, :, :])
            #self.visheatmap.show('Weight map', weights.data.cpu()[0].numpy()[0,:,:])
            self.visheatmap.show('Image',
                                 inputs.data.cpu()[0].numpy()[0, :, :])
            self.visheatmap.show('Max prob',
                                 maxprob.cpu().numpy().astype(np.float32))
            for k in range(prob.shape[0]):
                self.visheatmap.show('Heat map {}'.format(k),
                                     prob.cpu()[k].numpy())

        print(f"End Val: wPQ{pq_weight}")
        return pq_weight
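
The returned score is a cell-weighted mean of the per-batch panoptic quality: each batch's pq counts in proportion to its number of cells, wPQ = sum(pq_i * n_i) / sum(n_i). The same reduction over collected per-batch results (hypothetical names):

def weighted_pq(pqs, cell_counts):
    """Cell-weighted panoptic quality, mirroring the accumulation loop above."""
    total = sum(cell_counts)
    return sum(p * n for p, n in zip(pqs, cell_counts)) / total if total else 0.0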
    def training(self, data_loader, epoch=0):
        # reset logger
        self.logger_train.reset()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, sample in enumerate(data_loader):

            # measure data loading time
            data_time.update(time.time() - end)
            # get data (image, label, weight)
            inputs, targets = sample['image'], sample['label']
            weights = None
            if 'weight' in sample.keys():
                weights = sample['weight']

            batch_size = inputs.shape[0]

            if self.cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()
                if weights is not None:
                    weights = weights.cuda()

            # fit (forward)
            if self.half_precision:
                with torch.cuda.amp.autocast():

                    loss, outputs = self.step(inputs, targets)

                    self.optimizer.zero_grad()
                    self.scaler.scale(loss * batch_size).backward()
                    self.scaler.step(self.optimizer)
                    self.scaler.update()

            else:
                loss, outputs = self.step(inputs, targets)

                self.optimizer.zero_grad()
                (batch_size * loss).backward()
                self.optimizer.step()

            accs = self.accuracy(outputs, targets)
            dices = self.dice(outputs, targets)
            #pq    = metrics.pq_metric(outputs.cpu().detach().numpy(), targets.cpu().detach().numpy())
            #pq, n_cells  = metrics.pq_metric(targets, outputs)

            # update
            self.logger_train.update(
                {'loss': loss.item()},
                {
                    'accs': accs.item(),
                    #'PQ': pq,
                    'dices': dices.item()
                },
                batch_size,
            )

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.print_freq == 0:
                self.logger_train.logger(
                    epoch,
                    epoch + float(i + 1) / len(data_loader),
                    i,
                    len(data_loader),
                    batch_time,
                )
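
One detail worth flagging in the half-precision branch above: the backward pass and the scaler calls run inside the autocast() context. PyTorch's recommended pattern keeps only the forward pass under autocast and moves the scaled backward and optimizer step outside it. A minimal sketch of that canonical loop body, keeping this example's (loss * batch_size) scaling and assuming the same step/scaler members:

# sketch of the per-batch body inside the same training method
self.optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss, outputs = self.step(inputs, targets)    # forward + loss only

self.scaler.scale(loss * batch_size).backward()   # scaled backward, outside autocast
self.scaler.step(self.optimizer)                  # unscales grads, then steps
self.scaler.update()                              # adjusts the loss scale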