Example #1
class Predictor:
    def __init__(self,
                 data_dir,
                 ckp_path,
                 save_lr=False,
                 list_dir='',
                 out_dir='./',
                 log_dir=''):
        super(Predictor, self).__init__()
        self.data_dir = data_dir
        self.list_dir = list_dir
        self.out_dir = out_dir
        self.checkpoint = ckp_path
        self.save_lr = save_lr

        self.logger = Logger(scrn=True, log_dir=log_dir, phase='test')

        self.model = None

    def test_epoch(self):
        raise NotImplementedError

    def test(self):
        if self.checkpoint:
            if self._resume_from_checkpoint():
                self.model.cuda()
                self.model.eval()
                self.test_epoch()
        else:
            self.logger.warning("no checkpoint assigned!")

    def _resume_from_checkpoint(self):
        if not os.path.isfile(self.checkpoint):
            self.logger.error("=> no checkpoint found at '{}'".format(
                self.checkpoint))
            return False

        self.logger.show("=> loading checkpoint '{}'".format(self.checkpoint))
        checkpoint = torch.load(self.checkpoint)

        state_dict = self.model.state_dict()
        ckp_dict = checkpoint.get('state_dict', checkpoint)

        try:
            state_dict.update(ckp_dict)
            # Load the merged state dict; load_state_dict raises RuntimeError
            # (not only KeyError) on missing or mismatched keys.
            self.model.load_state_dict(state_dict)
        except (KeyError, RuntimeError) as e:
            self.logger.error("=> mismatched checkpoint for test")
            self.logger.error(e)
            return False
        else:
            self.epoch = checkpoint.get('epoch', 0)

        self.logger.show("=> loaded checkpoint '{}'".format(self.checkpoint))
        return True
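
A minimal usage sketch for the base class above, assuming the project-internal Logger is importable: a concrete predictor supplies the model and the test_epoch logic before test() is called. The MyModel and MyPredictor names are placeholders for illustration, not part of the original source.

import torch.nn as nn

class MyModel(nn.Module):
    # Trivial stand-in network; the project's real model goes here.
    def forward(self, x):
        return x

class MyPredictor(Predictor):
    def __init__(self, data_dir, ckp_path, **kwargs):
        super().__init__(data_dir, ckp_path, **kwargs)
        self.model = MyModel()

    def test_epoch(self):
        # Iterate over the test set, run self.model, and save or score outputs.
        pass

# predictor = MyPredictor('/path/to/data', '/path/to/checkpoint.pth')
# predictor.test()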
Example #2
class Predictor:
    modes = ['dataloader', 'folder', 'list', 'file', 'data']

    def __init__(self,
                 model=None,
                 mode='folder',
                 save_dir=None,
                 scrn=True,
                 log_dir=None,
                 cuda_off=False):

        self.save_dir = save_dir
        self.output = None

        if not cuda_off and torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        assert model is not None, "The model must be assigned"
        self.model = self._model_init(model)

        if mode not in Predictor.modes:
            raise NotImplementedError

        self.logger = Logger(scrn=scrn, log_dir=log_dir, phase='predict')

        if mode == 'dataloader':
            self._predict = partial(self._predict_dataloader,
                                    dataloader=None,
                                    save_dir=save_dir)
        elif mode == 'folder':
            # self.suffix = ['.jpg', '.png', '.bmp', '.gif', '.npy']  # supported image formats
            self._predict = partial(self._predict_folder, save_dir=save_dir)
        elif mode == 'list':
            self._predict = partial(self._predict_list, save_dir=save_dir)
        elif mode == 'file':
            self._predict = partial(self._predict_file, save_dir=save_dir)
        elif mode == 'data':
            # _predict_data does not take a save_dir argument
            self._predict = self._predict_data
        else:
            raise NotImplementedError

    def __call__(self, *args, **kwargs):
        return self._predict(*args, **kwargs)

    def _model_init(self, model):
        model.to(self.device)
        model.eval()
        return model

    def _load_data(self, path):
        return io.imread(path)

    def _to_tensor(self, arr):
        return to_tensor(arr)

    def _to_array(self, tensor):
        return to_array(tensor)

    def _normalize(self, tensor):
        return normalize(tensor)

    def _np2tensor(self, arr):
        nor_tensor = self._normalize(self._to_tensor(arr))
        assert isinstance(nor_tensor, torch.Tensor)
        return nor_tensor

    def _save_data_NTIRE2020(self, data, path):
        s_dir = os.path.dirname(path)
        if s_dir and not os.path.exists(s_dir):
            os.makedirs(s_dir)
        path = path.replace('_clean.png',
                            '.mat').replace('_RealWorld.png', '.mat')
        if isinstance(data, torch.Tensor):
            data = self._to_array(data).squeeze()

        content = {}
        content['cube'] = data
        content['bands'] = np.array([[
            400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520,
            530, 540, 550, 560, 570, 580, 590, 600, 610, 620, 630, 640, 650,
            660, 670, 680, 690, 700
        ]])
        # content['norm_factor'] =
        hdf5.write(data=content,
                   filename=path,
                   store_python_metadata=True,
                   matlab_compatible=True)

    def _save_data(self, data, path):
        s_dir = os.path.dirname(path)
        if s_dir and not os.path.exists(s_dir):
            os.makedirs(s_dir)

        torchvision.utils.save_image(data, path)

    def predict_base(self, model, data, path=None):
        start = time.time()
        with torch.no_grad():
            output = model(data)
        if self.device.type == 'cuda':
            torch.cuda.synchronize()
        su_time = time.time() - start
        if path:
            self._save_data_NTIRE2020(output, path)

        self.output = output
        return output, su_time

    def _predict_dataloader(self, dataloader, save_dir=None):
        assert dataloader is not None, \
            "In 'dataloader' mode the input must be a valid dataloader!"
        consume_time = AverageMeter()
        pb = tqdm(dataloader)
        for idx, (name, data) in enumerate(pb):
            assert isinstance(data, torch.Tensor) and data.dim() == 4, \
                "the input data must be a 4-dimensional tensor"
            data = data.to(self.device)  # 4-d tensor
            save_path = os.path.join(save_dir, name) if save_dir else None
            _, su_time = self.predict_base(self.model, data, path=save_path)
            consume_time.update(su_time, n=1)

            # logger
            description = (
                "[{}/{}] speed: {time.val:.4f}s({time.avg:.4f}s)".format(
                    idx + 1, len(dataloader), time=consume_time))
            pb.set_description(description)
            self.logger.dump(description)

    def _predict_folder(self, folder, save_dir=None):
        assert folder is not None and os.path.isdir(folder),\
        "In 'folder' mode the input must be a valid path of a folder!"
        consume_time = AverageMeter()
        file_list = glob.glob(os.path.join(folder, '*'))

        assert len(file_list) > 0, "The input folder is empty"

        pb = tqdm(file_list)  # progress bar

        for idx, file in enumerate(pb):
            img = self._load_data(file)
            name = os.path.basename(file)
            img = self._np2tensor(img).unsqueeze(0).to(self.device)
            save_path = os.path.join(save_dir, name) if save_dir else None
            _, su_time = self.predict_base(model=self.model,
                                           data=img,
                                           path=save_path)
            consume_time.update(su_time)

            # logger
            description = (
                "[{}/{}] speed: {time.val:.4f}s({time.avg:.4f}s)".format(
                    idx + 1, len(file_list), time=consume_time))
            pb.set_description(description)
            self.logger.dump(description)

    def _predict_list(self, file_list, save_dir=None):
        assert isinstance(file_list, list),\
        "In 'list' mode the input must be a valid file_path list!"
        consume_time = AverageMeter()

        assert len(file_list) > 0, "The input file list is empty!"

        pb = tqdm(file_list)  # progress bar

        for idx, path in enumerate(pb):
            data = self._load_data(path)
            name = os.path.basename(path)
            data = self._np2tensor(data).unsqueeze(0).to(self.device)
            path = os.path.join(save_dir, name) if save_dir else None
            _, su_time = self.predict_base(model=self.model,
                                           data=data,
                                           path=path)
            consume_time.update(su_time, n=1)

            # logger
            description = (
                "[{}/{}] speed: {time.val:.4f}s({time.avg:.4f}s)".format(
                    idx + 1, len(file_list), time=consume_time))
            pb.set_description(description)
            self.logger.dump(description)

    def _predict_file(self, file_path, save_dir=None):
        assert isinstance(file_path, str) and os.path.isfile(file_path), \
            "In 'file' mode the input must be a valid path of a file!"

        consume_time = AverageMeter()
        data = self._load_data(file_path)
        name = os.path.basename(file_path)
        data = self._np2tensor(data).unsqueeze(0).to(self.device)
        path = os.path.join(save_dir, name) if save_dir else None

        _, su_time = self.predict_base(model=self.model, data=data, path=path)
        consume_time.update(su_time)

        # logger
        description = ("file: {}  speed: {time.val:.4f}s".format(
            name, time=consume_time))

        self.logger.show(description)

    def _predict_data(self, data):
        """Run the model on an in-memory 4-d tensor and return the output tensor."""

        assert isinstance(data, torch.Tensor) and data.dim() == 4, \
        "In 'data' mode the input must be a 4-d tensor"

        consume_time = AverageMeter()
        output, su_time = self.predict_base(model=self.model, data=data)

        consume_time.update(su_time)

        # logger
        description = ("speed: {time.val:.4f}s".format(time=consume_time))

        self.logger.dump(description)

        return output
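
A brief usage sketch for the class above, assuming the project-internal helpers (Logger, to_tensor, normalize, etc.) are importable; the stand-in model, the 'data' mode, and the tensor shape are assumptions for illustration only.

import torch
import torch.nn as nn

model = nn.Identity()                      # stand-in for a trained network
predictor = Predictor(model=model, mode='data', cuda_off=True)
dummy = torch.zeros(1, 3, 64, 64)          # 4-d tensor, as the 'data' mode requires
output = predictor(dummy)                  # dispatches to _predict_data
print(output.shape)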
Example #3
class Trainer:
    def __init__(self, settings):
        super(Trainer, self).__init__()
        self.settings = settings
        self.phase = settings.cmd
        self.batch_size = settings.batch_size
        self.data_dir = settings.data_dir
        self.list_dir = settings.list_dir
        self.checkpoint = settings.resume
        self.load_checkpoint = (len(self.checkpoint) > 0)
        self.num_epochs = settings.num_epochs
        self.lr = float(settings.lr)
        self.save = settings.save_on or settings.out_dir
        self.from_pause = self.settings.continu
        self.path_ctrl = settings.global_path
        self.path = self.path_ctrl.get_path

        log_dir = '' if settings.log_off else self.path_ctrl.get_dir('log')
        self.logger = Logger(scrn=True, log_dir=log_dir, phase=self.phase)

        for k, v in sorted(settings.__dict__.items()):
            self.logger.show("{}: {}".format(k, v))

        self.start_epoch = 0
        self._init_max_acc = 0.0

        self.model = None
        self.criterion = None

    def train_epoch(self):
        raise NotImplementedError

    def validate_epoch(self, epoch, store):
        raise NotImplementedError

    def train(self):
        cudnn.benchmark = True

        if self.load_checkpoint:
            self._resume_from_checkpoint()
        max_acc = self._init_max_acc
        best_epoch = self.get_ckp_epoch()

        self.model.cuda()
        self.criterion.cuda()

        end_epoch = self.num_epochs if self.from_pause else self.start_epoch + self.num_epochs
        for epoch in range(self.start_epoch, end_epoch):
            lr = self._adjust_learning_rate(epoch)

            self.logger.show_nl("Epoch: [{0}]\tlr {1:.06f}".format(epoch, lr))
            # Train for one epoch
            self.train_epoch()

            # Evaluate the model on validation set
            self.logger.show_nl("Validate")
            acc = self.validate_epoch(epoch=epoch, store=self.save)

            is_best = acc > max_acc
            if is_best:
                max_acc = acc
                best_epoch = epoch
            self.logger.show_nl(
                "Current: {:.6f} ({:03d})\tBest: {:.6f} ({:03d})\t".format(
                    acc, epoch, max_acc, best_epoch))

            # The checkpoint stores the next epoch to run
            self._save_checkpoint(self.model.state_dict(), max_acc, epoch + 1,
                                  is_best)

    def validate(self):
        if self.checkpoint:
            if self._resume_from_checkpoint():
                self.model.cuda()
                self.criterion.cuda()
                self.validate_epoch(self.get_ckp_epoch(), self.save)
        else:
            self.logger.warning("no checkpoint assigned!")

    def _load_pretrained(self):
        raise NotImplementedError

    def _adjust_learning_rate(self, epoch):
        # Note that this has no effect when separate per-parameter-group learning rates are used
        start_epoch = 0 if self.from_pause else self.start_epoch
        if self.settings.lr_mode == 'step':
            lr = self.lr * (0.5**((epoch - start_epoch) // self.settings.step))
        elif self.settings.lr_mode == 'poly':
            lr = self.lr * (1 - (epoch - start_epoch) /
                            (self.num_epochs - start_epoch))**1.1
        elif self.settings.lr_mode == 'const':
            lr = self.lr
        else:
            raise ValueError('unknown lr mode {}'.format(
                self.settings.lr_mode))

        if lr == self.lr:
            return self.lr

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    def _resume_from_checkpoint(self):
        if not os.path.isfile(self.checkpoint):
            self.logger.error("=> no checkpoint found at '{}'".format(
                self.checkpoint))
            return False

        self.logger.show("=> loading checkpoint '{}'".format(self.checkpoint))
        checkpoint = torch.load(self.checkpoint)

        state_dict = self.model.state_dict()
        ckp_dict = checkpoint.get('state_dict', checkpoint)
        update_dict = {
            k: v
            for k, v in ckp_dict.items()
            if k in state_dict and state_dict[k].shape == v.shape
        }

        num_to_update = len(update_dict)
        if (num_to_update < len(state_dict)) or (len(state_dict) <
                                                 len(ckp_dict)):
            if self.phase == 'val':
                self.logger.error("=> mismatched checkpoint for validation")
                return False
            self.logger.warning(
                "warning: trying to load a mismatched checkpoint")
            if num_to_update == 0:
                self.logger.error("=> no parameter is to be loaded")
                return False
            else:
                self.logger.warning(
                    "=> {} params are to be loaded".format(num_to_update))
        elif (not self.settings.anew) or (self.phase != 'train'):
            # Note that in the non-anew mode, the max_acc restored here is not
            # guaranteed to correspond to the loaded checkpoint.
            self.start_epoch = checkpoint.get('epoch', self.start_epoch)
            self._init_max_acc = checkpoint.get('max_acc', self._init_max_acc)

        state_dict.update(update_dict)
        self.model.load_state_dict(state_dict)

        self.logger.show(
            "=> loaded checkpoint '{}' (epoch {}, max_acc {:.4f})".format(
                self.checkpoint, self.get_ckp_epoch(), self._init_max_acc))
        return True

    def _save_checkpoint(self, state_dict, max_acc, epoch, is_best):
        state = {'epoch': epoch, 'state_dict': state_dict, 'max_acc': max_acc}
        # Save history
        history_path = self.path('weight',
                                 CKP_COUNTED.format(e=epoch, s=self.scale),
                                 underline=True)
        if (epoch - self.start_epoch) % self.settings.trace_freq == 0:
            torch.save(state, history_path)
        # Save latest
        latest_path = self.path('weight',
                                CKP_LATEST.format(s=self.scale),
                                underline=True)
        torch.save(state, latest_path)
        if is_best:
            shutil.copyfile(
                latest_path,
                self.path('weight',
                          CKP_BEST.format(s=self.scale),
                          underline=True))

    def get_ckp_epoch(self):
        # Get the current epoch of the checkpoint.
        # For a mismatched or missing checkpoint, this returns 0.
        return max(self.start_epoch - 1, 0)
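
A minimal sketch of how a concrete trainer might plug into this base class. The MyTrainer name, the model and criterion choices, and the build_model helper are hypothetical; the settings object is whatever the project's argument parser produces.

import torch

class MyTrainer(Trainer):
    def __init__(self, settings):
        super().__init__(settings)
        self.model = build_model(settings)    # hypothetical model factory
        self.criterion = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def train_epoch(self):
        # One pass over the training loader: forward, loss, backward, step.
        pass

    def validate_epoch(self, epoch, store):
        # Evaluate on the validation set and return a scalar accuracy;
        # train() uses this value to track the best checkpoint.
        return 0.0

# trainer = MyTrainer(settings)
# trainer.train()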