Example #1
    def _print_log(self, step, log_values, title='', max_n_batch=None):
        log_str = '{}\n'.format(self.params.exp_name)
        log_str += '{}: epoch {}'.format(title, self.last_epoch)

        if max_n_batch:
            log_str += '[{}/{}], lr: {}'.format(step, max_n_batch, get_learning_rates(self.optimizer))

        # global_step = step + (self.last_epoch - 1) * self.batch_per_epoch
        for k, v in log_values.items():
            if isinstance(v, meter_utils.AverageValueMeter):
                mean, std = v.value()
                log_str += '\n\t{}: {:.10f}'.format(k, mean)

        if max_n_batch:
            # timing stats: data/batch time, throughput, and estimated time remaining
            data_time = self.data_timer.duration + 1e-6
            batch_time = self.batch_timer.duration + 1e-6
            rest_seconds = int((max_n_batch - step) * batch_time)
            log_str += '\n\t({:.2f}/{:.2f}s,' \
                       ' fps:{:.1f}, rest: {})'.format(data_time, batch_time,
                                                       self.params.batch_size / batch_time,
                                                       str(datetime.timedelta(seconds=rest_seconds)))
            self.batch_timer.clear()
            self.data_timer.clear()

        logger.info(log_str)
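
These snippets lean on a small Timer helper with tic/toc/duration/clear. Its source is not shown here; the sketch below is inferred from the call sites (duration is read as a per-interval time and cleared after each print), so the averaging semantics are an assumption, not the repo's actual implementation.

import time

class Timer:
    # Tracks the average length of tic()/toc() intervals since the last clear().
    def __init__(self):
        self.clear()

    def clear(self):
        self._total = 0.0
        self._calls = 0
        self._start = None

    def tic(self):
        # mark the start of a measured interval
        self._start = time.time()

    def toc(self):
        # close the interval opened by the last tic()
        if self._start is not None:
            self._total += time.time() - self._start
            self._calls += 1
            self._start = None
        return self.duration

    @property
    def duration(self):
        # average interval length; 0.0 before any toc()
        return self._total / self._calls if self._calls else 0.0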
Example #2
    def _val_one_epoch(self, n_batch):
        training_mode = self.model.training
        self.model.eval()
        logs = OrderedDict()
        sum_loss = meter_utils.AverageValueMeter()
        logger.info('Evaluating on validation set...')

        self.batch_timer.clear()
        self.data_timer.clear()
        self.batch_timer.tic()
        self.data_timer.tic()
        for step, batch in enumerate(self.val_data):
            self.data_timer.toc()
            if step >= n_batch:  # stop after n_batch validation batches
                break

            inputs, gts, _ = self.batch_processor(self, batch)
            _, saved_for_loss, _ = self.model(*inputs)
            self.batch_timer.toc()

            loss, saved_for_log = self.model.module.build_loss(saved_for_loss, *gts)
            sum_loss.add(loss.item())
            self._process_log(saved_for_log, logs)

            if step % self.params.print_freq == 0:
                self._print_log(step, logs, 'Validation', max_n_batch=min(n_batch, len(self.val_data)))

            self.data_timer.tic()
            self.batch_timer.tic()

        mean, std = sum_loss.value()
        logger.info('Validation loss: mean: {}, std: {}'.format(mean, std))
        self.model.train(mode=training_mode)
        return mean
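
meter_utils.AverageValueMeter is not shown in this listing. A minimal stand-in consistent with how it is called above (add() accumulates scalars, value() returns a (mean, std) pair) could look like this; it is an inference from usage, not the repo's code.

import math

class AverageValueMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.n = 0
        self.sum = 0.0
        self.sum_sq = 0.0

    def add(self, value):
        # accumulate a scalar observation
        self.n += 1
        self.sum += value
        self.sum_sq += value * value

    def value(self):
        # return (mean, std) over everything added since reset()
        if self.n == 0:
            return float('nan'), float('nan')
        mean = self.sum / self.n
        var = max(self.sum_sq / self.n - mean * mean, 0.0)
        return mean, math.sqrt(var)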
Example #3
    def _load_ckpt(self, ckpt):
        epoch, state_dicts = net_utils.load_net(ckpt, self.model, load_state_dict=True)
        if not self.params.ignore_opt_state and not self.params.zero_epoch and epoch >= 0:
            self.last_epoch = epoch
            logger.info('Set last epoch to {}'.format(self.last_epoch))
            if state_dicts is not None:
                self.optimizer.load_state_dict(state_dicts[0])
                net_utils.set_optimizer_state_devices(self.optimizer.state, self.params.gpus[0])
                logger.info('Loaded optimizer state from checkpoint; '
                            'new learning rate: {}'.format(get_learning_rates(self.optimizer)))
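
get_learning_rates is used throughout these examples but never defined here. A one-liner like the following matches every call site; the repo's actual definition may differ.

def get_learning_rates(optimizer):
    # one learning rate per parameter group of a torch.optim.Optimizer
    return [group['lr'] for group in optimizer.param_groups]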
Example #4
    def _save_ckpt(self, save_to):
        model = self.model.module if isinstance(
            self.model, nn.DataParallel) else self.model
        net_utils.save_net(save_to,
                           model,
                           epoch=self.last_epoch,
                           optimizers=[self.optimizer],
                           rm_prev_opt=True,
                           max_n_ckpts=self.params.save_nckpt_max)
        logger.info('Saved ckpt to {}'.format(save_to))
Example #5
    def __init__(self, model, train_params, batch_processor, train_data, val_data=None):
        assert isinstance(train_params, TrainParams)
        self.params = train_params

        # Data loaders
        self.train_data = train_data
        self.val_data = val_data # sDataLoader.copy(val_data) if isinstance(val_data, DataLoader) else val_data
        # self.val_stream = self.val_data.get_stream() if self.val_data else None

        self.batch_processor = batch_processor
        self.batch_per_epoch = len(self.train_data)

        # set CUDA_VISIBLE_DEVICES=gpus
        gpus = ','.join([str(x) for x in self.params.gpus])
        os.environ['CUDA_VISIBLE_DEVICES'] = gpus
        self.params.gpus = tuple(range(len(self.params.gpus)))
        logger.info('Set CUDA_VISIBLE_DEVICES to {}...'.format(gpus))

        # Optimizer and learning rate
        self.last_epoch = 0
        self.optimizer = self.params.optimizer  # type: Optimizer
        if not isinstance(self.optimizer, Optimizer):
            logger.error('optimizer should be an instance of Optimizer, '
                         'but got {}'.format(type(self.optimizer)))
            raise ValueError
        self.lr_scheduler = self.params.lr_scheduler  # type: ReduceLROnPlateau or _LRScheduler
        if self.lr_scheduler and not isinstance(self.lr_scheduler, (ReduceLROnPlateau, _LRScheduler)):
            logger.error('lr_scheduler should be an instance of _LRScheduler or ReduceLROnPlateau, '
                         'but got {}'.format(type(self.lr_scheduler)))
            raise ValueError
        logger.info('Set lr_scheduler to {}'.format(type(self.lr_scheduler)))

        self.log_values = OrderedDict()
        self.batch_timer = Timer()
        self.data_timer = Timer()

        # load model
        self.model = model
        ckpt = self.params.ckpt
        if not self.params.save_dir:
            self.params.save_dir = os.path.join('outputs', self.params.exp_name)
        mkdir(self.params.save_dir)
        logger.info('Set output dir to {}'.format(self.params.save_dir))
        if ckpt is None:
            # find the last ckpt in save_dir, if any
            ckpts = [fname for fname in os.listdir(self.params.save_dir)
                     if os.path.splitext(fname)[-1] == '.h5']
            if ckpts:
                ckpts.sort(key=lambda name: int(os.path.splitext(name)[0].split('_')[-1]))
                ckpt = os.path.join(self.params.save_dir, ckpts[-1])

        if ckpt is not None and not self.params.re_init:
            self._load_ckpt(ckpt)
            logger.info('Loaded ckpt from {}'.format(ckpt))
        #elif hasattr(self.model, 'init_weight'):
        #    self.model.init_weight()
        #    logger.info("Re-init model weight")

        self.model = ListDataParallel(self.model, device_ids=self.params.gpus)
        self.model = self.model.cuda(self.params.gpus[0])
        self.model.train()
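
The constructor masks GPUs with CUDA_VISIBLE_DEVICES and then rewrites params.gpus to range(len(gpus)), because the remaining devices are re-indexed from zero. A standalone illustration of that remapping follows; the device ids are arbitrary, it needs a machine with at least four GPUs, and the environment variable must be set before CUDA is first initialised.

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'  # expose only physical GPUs 2 and 3

import torch
# The two visible GPUs are now addressed as cuda:0 and cuda:1,
# which is exactly why the trainer renumbers self.params.gpus.
print(torch.cuda.device_count())  # prints 2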
Example #6
    def __init__(self, model, train_params, batch_processor=None, val_data=None):
        assert isinstance(train_params, TestParams)
        self.params = train_params
        self.batch_timer = Timer()
        self.data_timer = Timer()
        self.val_data = val_data
        self.batch_processor = batch_processor

        # load model
        self.model = model
        ckpt = self.params.ckpt

        if ckpt is not None:
            self._load_ckpt(ckpt)
            logger.info('Loaded ckpt from {}'.format(ckpt))

        self.model = nn.DataParallel(self.model, device_ids=self.params.gpus)
        self.model = self.model.cuda(device=self.params.gpus[0])
        self.model.eval()
Example #7
import os
import pickle
from copy import deepcopy

# save_net also relies on a module-level logger and on
# set_optimizer_state_devices defined elsewhere in the same net_utils module.
def save_net(fname, net, epoch=-1, optimizers=None, rm_prev_opt=False, max_n_ckpts=-1):
    import h5py
    with h5py.File(fname, mode='w') as h5f:
        for k, v in net.state_dict().items():
            h5f.create_dataset(k, data=v.cpu().numpy())
        h5f.attrs['epoch'] = epoch

    if optimizers is not None:
        state_dicts = []
        for optimizer in optimizers:
            state_dict = deepcopy(optimizer.state_dict())
            state_dict['state'] = set_optimizer_state_devices(state_dict['state'], device_id=None)
            state_dicts.append(state_dict)

        state_file = fname + '.optimizer_state.pk'
        with open(state_file, 'wb') as f:
            pickle.dump(state_dicts, f)

        # remove optimizer-state files left over from previous checkpoints
        if rm_prev_opt:
            root = os.path.split(fname)[0]
            for filename in os.listdir(root):
                filename = os.path.join(root, filename)
                if filename.endswith('.optimizer_state.pk') and filename != state_file:
                    logger.info('Remove {}'.format(filename))
                    os.remove(filename)

        # keep at most max_n_ckpts checkpoints, removing the oldest
        if max_n_ckpts > 0:
            root = os.path.split(fname)[0]
            ckpts = [name for name in os.listdir(root) if os.path.splitext(name)[-1] == '.h5']
            ckpts = sorted(ckpts, key=lambda name: int(os.path.splitext(name)[0].split('_')[-1]))
            if len(ckpts) > max_n_ckpts:
                for ckpt in ckpts[0:-max_n_ckpts]:
                    filename = os.path.join(root, ckpt)
                    logger.info('Remove {}'.format(filename))
                    os.remove(filename)
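
A plausible counterpart that reads back what save_net writes is sketched below. The real net_utils.load_net used in Example #3 may differ in signature and behaviour; this is only an inference from the file format above.

import os
import pickle

import h5py
import torch

def load_net_sketch(fname, net):
    # read the weights and epoch attribute written by save_net
    with h5py.File(fname, mode='r') as h5f:
        epoch = int(h5f.attrs.get('epoch', -1))
        state_dict = {k: torch.from_numpy(v[...]) for k, v in h5f.items()}
    net.load_state_dict(state_dict)

    # the pickled optimizer states live next to the .h5 file
    state_dicts = None
    state_file = fname + '.optimizer_state.pk'
    if os.path.exists(state_file):
        with open(state_file, 'rb') as f:
            state_dicts = pickle.load(f)
    return epoch, state_dicts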
Example #8
    def train(self):
        best_loss = np.inf
        for epoch in range(self.last_epoch, self.params.max_epoch):
            self.last_epoch += 1
            logger.info('Start training epoch {}'.format(self.last_epoch))

            for fun in self.on_start_epoch_hooks:
                fun(self)

            # adjust learning rate
            if isinstance(self.lr_scheduler, _LRScheduler):
                cur_lrs = get_learning_rates(self.optimizer)
                self.lr_scheduler.step(self.last_epoch)
                logger.info('Set learning rates from {} to {}'.format(
                    cur_lrs, get_learning_rates(self.optimizer)))

            train_loss = self._train_one_epoch()

            for fun in self.on_end_epoch_hooks:
                fun(self)

            # save model
            if (self.last_epoch % self.params.save_freq_epoch
                    == 0) or (self.last_epoch == self.params.max_epoch - 1):
                save_name = 'ckpt_{}.h5'.format(self.last_epoch)
                save_to = os.path.join(self.params.save_dir, save_name)
                self._save_ckpt(save_to)

                # find best model
                if self.params.val_nbatch_end_epoch > 0:
                    val_loss = self._val_one_epoch(
                        self.params.val_nbatch_end_epoch)
                    if val_loss < best_loss:
                        best_file = os.path.join(
                            self.params.save_dir,
                            'ckpt_{}_{:.4f}.h5.best'.format(
                                self.last_epoch, val_loss))
                        shutil.copyfile(save_to, best_file)
                        logger.info('Found a better ckpt ({:.3f} -> {:.3f}), '
                                    'saved to {}'.format(
                                        best_loss, val_loss, best_file))
                        best_loss = val_loss

                    if isinstance(self.lr_scheduler, ReduceLROnPlateau):
                        self.lr_scheduler.step(val_loss, self.last_epoch)
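
Judging from the loops at the top of train(), on_start_epoch_hooks and on_end_epoch_hooks are plain lists of callables that receive the trainer instance. A hypothetical registration (trainer and logger are assumed to exist in the caller's scope):

def log_epoch_end(trainer):
    # runs once per epoch, after _train_one_epoch returns
    logger.info('Epoch {} finished'.format(trainer.last_epoch))

trainer.on_end_epoch_hooks.append(log_epoch_end)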