예제 #1
0
    def train_epoch(self, epoch=0):
        """Train the sensitivity classifier for one epoch.

        Args:
            epoch: current epoch index; used only to offset tensorboard steps.
        """
        loss_meter = AverageMeter()
        n_batches = len(self.train_loader)
        progress = tqdm(self.train_loader)

        self.model.train()

        for step, (_, hsi) in enumerate(progress):
            hsi = hsi.to(self.device)
            # Draw a random discrete sensitivity function plus its class index.
            sens, idx = create_sensitivity('D')
            sens = sens.to(self.device)
            idx = torch.LongTensor([idx]).to(self.device)
            # Synthesize the RGB network input from sensitivity and HSI cube.
            rgb = create_rgb(sens, hsi)

            kls = self.model(rgb)

            # Add a batch dimension so the criterion sees (1, num_classes).
            loss = self.criterion(kls.unsqueeze(0), idx)
            loss_meter.update(loss.item(), n=self.batch_size)

            # Backpropagate and take an optimizer step.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            desc = self.logger.make_desc(
                step + 1, n_batches,
                ('loss', loss_meter, '.4f'),
            )

            progress.set_description(desc)
            self.logger.dump(desc)

            # @zjw: tensorboard
            global_step = n_batches * epoch + step
            self.logger.add_scalar('Classifier-Loss/train/losses', loss_meter.val, global_step)
            self.logger.add_scalar('Classifier-Lr', self.optimizer.param_groups[0]['lr'], global_step)
예제 #2
0
    def train_epoch(self, epoch=0):
        """Train the solver (HSI reconstruction network) for one epoch.

        Args:
            epoch: current epoch index; used only to offset tensorboard steps.
        """
        loss_meter = AverageMeter()
        n_batches = len(self.train_loader)
        progress = tqdm(self.train_loader)

        self.model.train()
        # Track gradients of the first, second and last layers for logging.
        self.logger.watch_grad(model=self.model, layers=[0, 1, -1])

        for step, (_, hsi) in enumerate(progress):
            hsi = hsi.to(self.device)

            # Sample a sensitivity function: discrete ('D') or continuous ('C').
            if self.sens_type == 'D':
                sens, _ = create_sensitivity('D')
                sens = sens.to(self.device)
            else:
                sens = create_sensitivity('C').to(self.device)

            # Create a RGB image as training input
            rgb = create_rgb(sens, hsi)

            # Optionally append the sensitivity, tiled spatially, as extra channels.
            if self.with_sens:
                tiled = sens.view(1, -1, 1, 1).repeat(rgb.size(0), 1, *rgb.shape[2:])
                rgb = torch.cat([rgb, tiled], dim=1)

            recon = self.model(rgb)
            # Discard the boundary pixels of hsi
            # hsi = hsi[..., self.cut:-self.cut, self.cut:-self.cut]

            loss = self.criterion(recon, hsi)
            loss_meter.update(loss.item(), n=self.batch_size)

            # Backpropagate and take an optimizer step.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            desc = self.logger.make_desc(
                step + 1, n_batches,
                ('loss', loss_meter, '.4f'),
            )

            progress.set_description(desc)
            self.logger.dump(desc)

            # @zjw: tensorboard
            global_step = n_batches * epoch + step
            self.logger.add_grads(global_step=global_step)
            self.logger.add_scalar('Solver-Loss/train/', loss_meter.val, global_step)
            self.logger.add_scalar('Solver-Lr', self.optimizer.param_groups[0]['lr'], global_step)
예제 #3
0
    def validate_epoch(self, epoch=0, store=False):
        """Validate the sensitivity classifier for one epoch.

        Args:
            epoch: current epoch index; used for logging offsets.
            store: when True, also log and save real/predicted image pairs.

        Returns:
            The first metric's running average, or a loss-derived accuracy
            proxy when no metrics are configured.
        """
        self.logger.show_nl("Epoch: [{0}]".format(epoch))
        loss_meter = AverageMeter()
        n_batches = len(self.val_loader)
        progress = tqdm(self.val_loader)

        self.model.eval()

        with torch.no_grad():
            for step, (name, _, hsi) in enumerate(progress):
                hsi = hsi.to(self.device)
                # Use the pre-generated validation sensitivity for this sample.
                sens, idx = self.sens_list[step]
                sens = sens.to(self.device)
                idx = torch.LongTensor([idx]).to(self.device)
                img_real = create_rgb(sens, hsi)

                kls = self.model(img_real)
                # Re-synthesize the RGB image from the predicted class index.
                pred, _ = create_sensitivity('D', torch.argmax(kls, dim=0).item())
                pred = pred.to(self.device)

                img_pred = create_rgb(pred, hsi)

                loss = self.criterion(kls.unsqueeze(0), idx)
                loss_meter.update(loss.item(), n=self.batch_size)

                # img_pred = to_array(img_pred[0])
                # img_real = to_array(img_real[0])

                for metric in self.metrics:
                    metric.update(img_pred, img_real)

                desc = self.logger.make_desc(
                    step + 1, n_batches,
                    ('loss', loss_meter, '.4f'),
                    *(
                        (metric.__name__, metric, '.4f')
                        for metric in self.metrics
                    )
                )

                progress.set_description(desc)
                self.logger.dump(desc)

                # @zjw: tensorboard
                global_step = n_batches * epoch + step
                self.logger.add_scalar('Classifier-Loss/validate/losses', loss_meter.val, global_step)
                for metric in self.metrics:
                    self.logger.add_scalar('Classifier-validate/metrics/' + metric.__name__, metric.val, global_step)

                if store:
                    # Side-by-side (real | predicted) comparison along width.
                    side_by_side = torch.cat((img_real, img_pred), dim=3)
                    self.logger.add_images('Classifier-validate/real-pred', side_by_side, global_step)
                    self.save_image_tensor(
                        self.gpc.add_suffix(name[0], suffix='real-pred', underline=True),
                        side_by_side, epoch)
                    # self.save_image(self.gpc.add_suffix(name[0], suffix='pred', underline=True),
                    #                 (img_pred * 255).astype('uint8'), epoch)
                    # self.save_image(self.gpc.add_suffix(name[0], suffix='real', underline=True),
                    #                 (img_real * 255).astype('uint8'), epoch)

        return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - loss_meter.avg, self._init_max_acc)
예제 #4
0
 def __init__(self, dataset, optimizer, settings):
     """Set up the residual hyperspectral inference trainer (MSE loss)."""
     super().__init__('residual_hyper_inference', dataset, 'MSE', optimizer, settings)
     ctx = self.ctx
     self.sens_type = ctx['sens_type']
     num_feats_in = ctx['num_feats_in']
     # Input is either plain RGB (3 channels) or RGB plus tiled sensitivity.
     assert num_feats_in in (3, ctx['num_feats_out'] * 3 + 3)
     self.logger.show_nl("Setting up sensitivity functions for validation")
     # Fix one sensitivity per validation sample for reproducible scoring.
     sens_count = len(self.val_loader)
     self.sens_list = [create_sensitivity(self.sens_type) for _ in tqdm(range(sens_count))]
     self.with_sens = num_feats_in > 3
     self.chop = ctx['chop']
예제 #5
0
    def train_epoch(self, epoch=0):
        """Train the sensitivity estimator for one epoch.

        Args:
            epoch: current epoch index; used only to offset tensorboard steps.
        """
        total_meter = AverageMeter()
        smooth_meter = AverageMeter()
        image_meter = AverageMeter()
        label_meter = AverageMeter()
        n_batches = len(self.train_loader)
        progress = tqdm(self.train_loader)

        self.model.train()

        for step, (_, hsi) in enumerate(progress):
            # sens for sensitivity and hsi for hyperspectral image
            hsi = hsi.to(self.device)
            sens = create_sensitivity('C').to(self.device)

            # The reconstructed RGB in the range [0,1]
            img_real = create_rgb(sens, hsi)

            pred = self.model(img_real)

            # Reconstruct RGB from sensitivity function and HSI
            img_pred = create_rgb(pred, hsi)

            # Three loss terms: smoothness of the predicted curve, image
            # reconstruction fidelity, and direct sensitivity (label) error.
            smooth_loss = self.smooth_criterion(pred, pred.size(0))
            image_loss = self.image_criterion(img_pred, img_real)
            label_loss = self.label_criterion(pred, sens)
            loss = self.calc_total_loss(image_loss, label_loss, smooth_loss)

            total_meter.update(loss.item(), n=self.batch_size)
            image_meter.update(image_loss.item(), n=self.batch_size)
            label_meter.update(label_loss.item(), n=self.batch_size)
            smooth_meter.update(smooth_loss.item(), n=self.batch_size)

            # Backpropagate and take an optimizer step.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            desc = self.logger.make_desc(
                step + 1, n_batches,
                ('loss', total_meter, '.4f'),
                ('IL', image_meter, '.4f'),
                ('LL', label_meter, '.4f'),
                ('SL', smooth_meter, '.4f')
            )

            progress.set_description(desc)
            self.logger.dump(desc)

            # @zjw: tensorboard
            global_step = epoch * n_batches + step
            self.logger.add_scalar('Estimator-Loss/train/total_losses', total_meter.val, global_step)
            self.logger.add_scalar('Estimator-Loss/train/image_losses', image_meter.val, global_step)
            self.logger.add_scalar('Estimator-Loss/train/label_losses', label_meter.val, global_step)
            self.logger.add_scalar('Estimator-Loss/train/smooth_losses', smooth_meter.val, global_step)
            self.logger.add_scalar('Estimator-Lr', self.optimizer.param_groups[0]['lr'], global_step)
예제 #6
0
 def __init__(self, dataset, optimizer, settings):
     """Set up the sensitivity classifier trainer (NLL loss)."""
     super().__init__('classifier', dataset, 'NLL', optimizer, settings)
     self.logger.show_nl("Setting up sensitivity functions for validation")
     # Pre-draw one discrete sensitivity per validation sample so that
     # validation scores are comparable across epochs.
     sens_count = len(self.val_loader)
     self.sens_list = [create_sensitivity('D') for _ in tqdm(range(sens_count))]
예제 #7
0
 def __init__(self, dataset, optimizer, settings):
     """Set up the sensitivity estimator trainer (smooth + image + label losses)."""
     super().__init__('estimator', dataset, 'SMOOTH+MSE+MSE', optimizer, settings)
     # The combined criterion unpacks into the three individual loss terms.
     (self.smooth_criterion,
      self.image_criterion,
      self.label_criterion) = self.criterion
     self.logger.show_nl("Setting up sensitivity functions for validation")
     # Pre-draw one continuous sensitivity per validation sample so that
     # validation scores are comparable across epochs.
     sens_count = len(self.val_loader)
     self.sens_list = [create_sensitivity('C') for _ in tqdm(range(sens_count))]