Example #1
import os

import torch

# Logger is assumed to come from a companion module in the same project.


class Predictor:
    def __init__(self,
                 data_dir,
                 ckp_path,
                 save_lr=False,
                 list_dir='',
                 out_dir='./',
                 log_dir=''):
        super().__init__()
        self.data_dir = data_dir
        self.list_dir = list_dir
        self.out_dir = out_dir
        self.checkpoint = ckp_path
        self.save_lr = save_lr

        self.logger = Logger(scrn=True, log_dir=log_dir, phase='test')

        self.model = None

    def test_epoch(self):
        raise NotImplementedError

    def test(self):
        if self.checkpoint:
            if self._resume_from_checkpoint():
                self.model.cuda()
                self.model.eval()
                self.test_epoch()
        else:
            self.logger.warning("no checkpoint assigned!")

    def _resume_from_checkpoint(self):
        if not os.path.isfile(self.checkpoint):
            self.logger.error("=> no checkpoint found at '{}'".format(
                self.checkpoint))
            return False

        self.logger.show("=> loading checkpoint '{}'".format(self.checkpoint))
        checkpoint = torch.load(self.checkpoint)

        state_dict = self.model.state_dict()
        ckp_dict = checkpoint.get('state_dict', checkpoint)

        try:
            # Merge the checkpoint weights into a copy of the model's state
            # dict so keys absent from the checkpoint keep their current values.
            state_dict.update(ckp_dict)
            self.model.load_state_dict(state_dict)
        except (KeyError, RuntimeError) as e:
            self.logger.error("=> mismatched checkpoint for test")
            self.logger.error(e)
            return False
        else:
            self.epoch = checkpoint.get('epoch', 0)

        self.logger.show("=> loaded checkpoint '{}'".format(self.checkpoint))
        return True
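
A minimal usage sketch, assuming a concrete subclass that builds the model and overrides test_epoch (the model, paths, and loop body below are placeholders):

class MyPredictor(Predictor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Hypothetical model; a real subclass would build its own network.
        self.model = torch.nn.Linear(16, 2)

    def test_epoch(self):
        # Run inference with gradients disabled.
        with torch.no_grad():
            ...  # iterate over a test loader and collect predictions


predictor = MyPredictor(data_dir='./data', ckp_path='./ckpt/model.pth')
predictor.test()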
Example #2
import os
import shutil

import torch
import torch.backends.cudnn as cudnn

# Logger and the CKP_* filename templates (CKP_COUNTED, CKP_LATEST, CKP_BEST)
# are assumed to come from companion modules in the same project.


class Trainer:
    def __init__(self, settings):
        super().__init__()
        self.settings = settings
        self.phase = settings.cmd
        self.batch_size = settings.batch_size
        self.data_dir = settings.data_dir
        self.list_dir = settings.list_dir
        self.checkpoint = settings.resume
        self.load_checkpoint = (len(self.checkpoint) > 0)
        self.num_epochs = settings.num_epochs
        self.lr = float(settings.lr)
        self.save = settings.save_on or settings.out_dir
        self.from_pause = self.settings.continu  # 'continu' avoids the keyword 'continue'
        self.path_ctrl = settings.global_path
        self.path = self.path_ctrl.get_path

        log_dir = '' if settings.log_off else self.path_ctrl.get_dir('log')
        self.logger = Logger(scrn=True, log_dir=log_dir, phase=self.phase)

        for k, v in sorted(settings.__dict__.items()):
            self.logger.show("{}: {}".format(k, v))

        self.start_epoch = 0
        self._init_max_acc = 0.0

        self.model = None
        self.criterion = None

    def train_epoch(self):
        raise NotImplementedError

    def validate_epoch(self, epoch, store):
        raise NotImplementedError

    def train(self):
        cudnn.benchmark = True

        if self.load_checkpoint:
            self._resume_from_checkpoint()
        max_acc = self._init_max_acc
        best_epoch = self.get_ckp_epoch()

        self.model.cuda()
        self.criterion.cuda()

        end_epoch = self.num_epochs if self.from_pause else self.start_epoch + self.num_epochs
        for epoch in range(self.start_epoch, end_epoch):
            lr = self._adjust_learning_rate(epoch)

            self.logger.show_nl("Epoch: [{0}]\tlr {1:.06f}".format(epoch, lr))
            # Train for one epoch
            self.train_epoch()

            # Evaluate the model on validation set
            self.logger.show_nl("Validate")
            acc = self.validate_epoch(epoch=epoch, store=self.save)

            is_best = acc > max_acc
            if is_best:
                max_acc = acc
                best_epoch = epoch
            self.logger.show_nl(
                "Current: {:.6f} ({:03d})\tBest: {:.6f} ({:03d})\t".format(
                    acc, epoch, max_acc, best_epoch))

            # Store epoch + 1 so that resuming continues from the next epoch
            self._save_checkpoint(self.model.state_dict(), max_acc, epoch + 1,
                                  is_best)

    def validate(self):
        if self.checkpoint:
            if self._resume_from_checkpoint():
                self.model.cuda()
                self.criterion.cuda()
                self.validate_epoch(self.get_ckp_epoch(), self.save)
        else:
            self.logger.warning("no checkpoint assigned!")

    def _load_pretrained(self):
        raise NotImplementedError

    def _adjust_learning_rate(self, epoch):
        # Note: this schedule does not account for parameter groups with
        # separate learning rates
        start_epoch = 0 if self.from_pause else self.start_epoch
        if self.settings.lr_mode == 'step':
            lr = self.lr * (0.5**((epoch - start_epoch) // self.settings.step))
        elif self.settings.lr_mode == 'poly':
            lr = self.lr * (1 - (epoch - start_epoch) /
                            (self.num_epochs - start_epoch))**1.1
        elif self.settings.lr_mode == 'const':
            lr = self.lr
        else:
            raise ValueError('unknown lr mode {}'.format(
                self.settings.lr_mode))

        if lr == self.lr:
            return self.lr

        # self.optimizer is assumed to be created by a concrete subclass.
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    def _resume_from_checkpoint(self):
        if not os.path.isfile(self.checkpoint):
            self.logger.error("=> no checkpoint found at '{}'".format(
                self.checkpoint))
            return False

        self.logger.show("=> loading checkpoint '{}'".format(self.checkpoint))
        checkpoint = torch.load(self.checkpoint)

        state_dict = self.model.state_dict()
        ckp_dict = checkpoint.get('state_dict', checkpoint)
        update_dict = {
            k: v
            for k, v in ckp_dict.items()
            if k in state_dict and state_dict[k].shape == v.shape
        }

        num_to_update = len(update_dict)
        if num_to_update < len(state_dict) or len(state_dict) < len(ckp_dict):
            if self.phase == 'val':
                self.logger.error("=> mismatched checkpoint for validation")
                return False
            self.logger.warning(
                "=> trying to load a mismatched checkpoint")
            if num_to_update == 0:
                self.logger.error("=> no parameter is to be loaded")
                return False
            else:
                self.logger.warning(
                    "=> {} params are to be loaded".format(num_to_update))
        elif (not self.settings.anew) or (self.phase != 'train'):
            # Note: in non-anew mode there is no guarantee that the stored
            # max_acc actually corresponds to the loaded checkpoint.
            self.start_epoch = checkpoint.get('epoch', self.start_epoch)
            self._init_max_acc = checkpoint.get('max_acc', self._init_max_acc)

        state_dict.update(update_dict)
        self.model.load_state_dict(state_dict)

        self.logger.show(
            "=> loaded checkpoint '{}' (epoch {}, max_acc {:.4f})".format(
                self.checkpoint, self.get_ckp_epoch(), self._init_max_acc))
        return True

    def _save_checkpoint(self, state_dict, max_acc, epoch, is_best):
        # self.scale is assumed to be set by a concrete subclass; it is used
        # in the checkpoint filename templates below.
        state = {'epoch': epoch, 'state_dict': state_dict, 'max_acc': max_acc}
        # Save history
        history_path = self.path('weight',
                                 CKP_COUNTED.format(e=epoch, s=self.scale),
                                 underline=True)
        if (epoch - self.start_epoch) % self.settings.trace_freq == 0:
            torch.save(state, history_path)
        # Save latest
        latest_path = self.path('weight',
                                CKP_LATEST.format(s=self.scale),
                                underline=True)
        torch.save(state, latest_path)
        if is_best:
            shutil.copyfile(
                latest_path,
                self.path('weight',
                          CKP_BEST.format(s=self.scale),
                          underline=True))

    def get_ckp_epoch(self):
        # Return the current epoch of the checkpoint.
        # For a mismatched checkpoint or no checkpoint, this is 0.
        return max(self.start_epoch - 1, 0)
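
A minimal usage sketch, assuming a concrete subclass and a parsed settings object (the model, criterion, optimizer, and scale below are placeholders):

class MyTrainer(Trainer):
    def __init__(self, settings):
        super().__init__(settings)
        # Hypothetical components; a real subclass builds its own.
        self.model = torch.nn.Linear(16, 2)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)
        self.scale = 1.0  # used in the checkpoint filename templates

    def train_epoch(self):
        ...  # one pass over the training loader: forward, loss, backward, step

    def validate_epoch(self, epoch, store):
        ...  # evaluate on the validation set
        return 0.0  # return the accuracy used for best-model tracking


trainer = MyTrainer(settings)  # settings parsed elsewhere, e.g. via argparse
trainer.train()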