Ejemplo n.º 1
0
    def validation_epoch_end(self, validation_step_outputs):
        """Evaluate the merged validation outputs and keep the best model.

        The per-step result dicts are merged into a single dict and passed to
        ``self.evaluator.evaluate``. The value stored under
        ``self.cfg.evaluator.save_key`` is compared with ``self.save_flag``;
        when it is strictly greater, ``self.save_flag`` is updated, the model
        is saved under ``<save_dir>/model_best`` and the evaluation results
        are appended to ``eval_results.txt`` in that directory.

        A warning is issued only when the save key is missing from the
        evaluation results; a metric that simply did not improve is silent.

        :param validation_step_outputs: iterable of dicts, merged with
            ``dict.update`` (later entries overwrite earlier ones that share
            a key)
        """
        results = {}
        for res in validation_step_outputs:
            results.update(res)

        eval_results = self.evaluator.evaluate(results,
                                               self.cfg.save_dir,
                                               self.current_epoch,
                                               self._logger,
                                               rank=self.local_rank)

        save_key = self.cfg.evaluator.save_key
        if save_key not in eval_results:
            # Without the tracked metric there is nothing to compare, so no
            # best-model checkpoint can be chosen for this epoch.
            warnings.warn(
                'Warning! Save_key is not in eval results! Only save model last!'
            )
            return

        metric = eval_results[save_key]

        # ------save best model--------
        if metric > self.save_flag:
            self.save_flag = metric
            best_save_path = os.path.join(self.cfg.save_dir, 'model_best')
            mkdir(self.local_rank, best_save_path)
            # TODO: replace with saving checkpoint
            save_model(self.local_rank, self.model,
                       os.path.join(best_save_path, 'model_best.pth'),
                       self.current_epoch + 1, self.global_step)
            txt_path = os.path.join(best_save_path, "eval_results.txt")
            # Only ranks below 1 write the text log.
            if self.local_rank < 1:
                with open(txt_path, "a") as f:
                    f.write("Epoch:{}\n".format(self.current_epoch + 1))
                    for k, v in eval_results.items():
                        f.write("{}: {}\n".format(k, v))
Ejemplo n.º 2
0
    def run(self, train_loader, val_loader, evaluator):
        """
        Drive the training schedule: optional warm-up, then one training
        pass per epoch with periodic validation and best-model saving.
        :param train_loader: data-loader passed to warm-up and the train epochs
        :param val_loader: data-loader passed to the validation epochs
        :param evaluator: object whose ``evaluate`` turns the validation
            outputs into a dict of metrics
        """
        first_epoch = self.epoch
        best_metric = -10

        # Warm-up only happens on a fresh run (first epoch is 1).
        needs_warmup = self.cfg.schedule.warmup.steps > 0 and first_epoch == 1
        if needs_warmup:
            self.logger.log('Start warming up...')
            self.warm_up(train_loader)
            # Reset every group to the configured lr once warm-up is done.
            for group in self.optimizer.param_groups:
                group['lr'] = self.cfg.schedule.optimizer.lr

        self._init_scheduler()
        self.lr_scheduler.last_epoch = first_epoch - 1

        # resume learning rate of last epoch
        if first_epoch > 1:
            groups = self.optimizer.param_groups
            resumed_lrs = self.lr_scheduler.get_lr()
            for group, resumed_lr in zip(groups, resumed_lrs):
                group['lr'] = resumed_lr

        final_epoch = self.cfg.schedule.total_epochs
        for epoch in range(first_epoch, final_epoch + 1):
            results, train_losses = self.run_epoch(epoch, train_loader, mode='train')
            self.lr_scheduler.step()
            save_model(self.rank, self.model,
                       os.path.join(self.cfg.save_dir, 'model_last.pth'),
                       epoch, self._iter, self.optimizer)
            for name, value in train_losses.items():
                self.logger.scalar_summary('Epoch_loss/' + name, 'train', value, epoch)

            # --------evaluate----------
            interval = self.cfg.schedule.val_intervals
            if interval > 0 and epoch % interval == 0:
                with torch.no_grad():
                    results, val_losses = self.run_epoch(self.epoch, val_loader, mode='val')
                for name, value in val_losses.items():
                    self.logger.scalar_summary('Epoch_loss/' + name, 'val', value, epoch)

                eval_results = evaluator.evaluate(results, self.cfg.save_dir, epoch,
                                                  self.logger, rank=self.rank)
                if self.cfg.evaluator.save_key not in eval_results:
                    warnings.warn('Warning! Save_key is not in eval results! Only save model last!')
                elif eval_results[self.cfg.evaluator.save_key] > best_metric:
                    # ------save best model--------
                    best_metric = eval_results[self.cfg.evaluator.save_key]
                    best_dir = os.path.join(self.cfg.save_dir, 'model_best')
                    mkdir(self.rank, best_dir)
                    save_model(self.rank, self.model,
                               os.path.join(best_dir, 'model_best.pth'),
                               epoch, self._iter, self.optimizer)
                    report_path = os.path.join(best_dir, "eval_results.txt")
                    # Only ranks below 1 append to the text report.
                    if self.rank < 1:
                        with open(report_path, "a") as report:
                            report.write("Epoch:{}\n".format(epoch))
                            report.writelines(
                                "{}: {}\n".format(key, val)
                                for key, val in eval_results.items())

            self.epoch += 1
Ejemplo n.º 3
0
    def run_epoch(self, epoch, data_loader, mode):
        """
        Run one pass over ``data_loader`` in the given mode.
        :param epoch: current epoch number (used for the sampler epoch, log
            messages and checkpoint file names)
        :param data_loader: data-loader of the train or test dataset
        :param mode: 'train' puts the model in train mode, any other value in
            eval mode; detections are only collected when mode is 'val'
        :return: tuple of (dict of post-processed outputs keyed by the first
            image id of each batch, dict of epoch-average losses)
        """
        net = self.model
        if mode != 'train':
            net.eval()
            torch.cuda.empty_cache()
        else:
            net.train()  # network in train mode
            # A rank above -1 means distributed training: the sampler needs
            # to be told which epoch it is in.
            if self.rank > -1:
                self.logger.log(
                    "distributed sampler set epoch at {}".format(epoch))
                data_loader.sampler.set_epoch(epoch)

        detections = {}
        epoch_meters = {}
        window_meters = {}
        total_iters = len(data_loader)

        for step, batch in enumerate(data_loader):
            if step >= total_iters:
                break

            batch['img'] = batch['img'].to(device=torch.device('cuda'),
                                           non_blocking=True)
            output, loss, loss_stats = self.run_step(net, batch, mode)

            if mode == 'val':  # TODO: eval
                batch_dets = net.module.head.post_process(output, batch)
                # Keyed by the id of the first image in the batch only.
                first_image_id = batch['img_info']['id'].cpu().numpy()[0]
                detections[first_image_id] = batch_dets

            # Feed the scalar value of every loss term into both the
            # epoch-long and the windowed meters.
            for name in loss_stats:
                value = loss_stats[name].mean().item()
                if name in epoch_meters:
                    epoch_meters[name].update(value)
                    window_meters[name].push(value)
                else:
                    epoch_meters[name] = AverageMeter(value)
                    window_meters[name] = MovingAverage(
                        value, window_size=self.cfg.log.interval)

            if step % self.cfg.log.interval == 0:
                pieces = [
                    '{}|Epoch{}/{}|Iter{}({}/{})| lr:{:.3e}| '.format(
                        mode, epoch, self.cfg.schedule.total_epochs,
                        self._iter, step, total_iters,
                        self.optimizer.param_groups[0]['lr'])
                ]
                for name in window_meters:
                    pieces.append('{}:{:.4f}| '.format(
                        name, window_meters[name].avg()))
                    if mode == 'train' and self.rank < 1:
                        self.logger.scalar_summary('Train_loss/' + name, mode,
                                                   window_meters[name].avg(),
                                                   self._iter)
                self.logger.log(''.join(pieces))

            # ----- save checkpoint in an epoch
            if step % self.cfg.schedule.save_interval == 0:
                ckpt_name = 'epoch{:d}_iter{:d}.pth'.format(epoch, step)
                save_model(self.rank, self.model,
                           os.path.join(self.cfg.save_dir, ckpt_name),
                           epoch, self._iter, self.optimizer)

            # The global iteration counter only advances while training.
            if mode == 'train':
                self._iter += 1

            del output, loss, loss_stats

        epoch_loss_dict = {}
        for name, meter in epoch_meters.items():
            epoch_loss_dict[name] = meter.avg

        return detections, epoch_loss_dict
Ejemplo n.º 4
0
    def run(self, train_loader, val_loader, evaluator):
        """
        Run the training schedule from ``self.epoch`` up to and including
        ``cfg.schedule.total_epochs``.

        Every epoch trains on ``train_loader``, steps the lr scheduler and
        saves ``model_last.pth``. When ``evaluator`` is None a checkpoint is
        additionally saved in a per-epoch directory; otherwise a validation
        pass runs every ``cfg.schedule.val_intervals`` epochs and the model
        with the best ``cfg.evaluator.save_key`` metric is kept under
        ``model_best``. ``self.epoch`` is incremented after every epoch.

        :param train_loader: data-loader used for warm-up and training
        :param val_loader: data-loader used for the validation pass
        :param evaluator: object providing ``evaluate``, or None to skip
            evaluation
        """
        start_epoch = self.epoch
        save_flag = -10
        # Warm-up only happens on a fresh run (start epoch is 1).
        if self.cfg.schedule.warmup.steps > 0 and start_epoch == 1:
            self.logger.log('Start warming up...')
            self.warm_up(train_loader)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.cfg.schedule.optimizer.lr

        self._init_scheduler()
        self.lr_scheduler.last_epoch = start_epoch - 1

        # ---------- traverse each epoch
        for epoch in range(start_epoch, self.cfg.schedule.total_epochs + 1):
            # ----- run an epoch on train dataset, schedule lr, save model and logging
            ret_dict, train_loss_dict = self.run_epoch(epoch,
                                                       train_loader,
                                                       mode='train')
            self.lr_scheduler.step()
            save_model(self.rank, self.model,
                       os.path.join(self.cfg.save_dir, 'model_last.pth'),
                       epoch, self._iter, self.optimizer)
            for k, v in train_loss_dict.items():
                self.logger.scalar_summary('Epoch_loss/' + k, 'train', v,
                                           epoch)

            # --------evaluate----------
            val_intervals = self.cfg.schedule.val_intervals
            if evaluator is None:
                # do not evaluate, save current epoch's checkpoint
                epoch_save_path = os.path.join(self.cfg.save_dir,
                                               'epoch_{:d}'.format(epoch))
                mkdir(self.rank, epoch_save_path)
                save_model(self.rank, self.model,
                           os.path.join(epoch_save_path, 'model_best.pth'),
                           epoch, self._iter, self.optimizer)
            elif val_intervals > 0 and epoch % val_intervals == 0:
                # A non-positive interval disables validation instead of
                # raising ZeroDivisionError on the modulo.
                with torch.no_grad():  # validation pass, no gradients
                    ret_dict, val_loss_dict = self.run_epoch(epoch,
                                                             val_loader,
                                                             mode='val')

                for k, v in val_loss_dict.items():
                    self.logger.scalar_summary('Epoch_loss/' + k, 'val', v,
                                               epoch)

                # ----- do evaluation, ret_dict, key: img_id, val: dets_dict
                # Start from None so an unknown evaluator name can neither
                # hit an unbound variable nor reuse a previous epoch's results.
                eval_results = None
                if self.cfg.evaluator.name == 'CocoDetectionEvaluator':
                    eval_results = evaluator.evaluate(ret_dict,
                                                      self.cfg.save_dir,
                                                      epoch,
                                                      self.logger,
                                                      rank=self.rank)
                elif self.cfg.evaluator.name == 'MyDetectionEvaluator':
                    eval_results = evaluator.evaluate(ret_dict)
                else:
                    warnings.warn('Unsupported evaluator name: {}'.format(
                        self.cfg.evaluator.name))

                # No `continue` here: the epoch counter below must advance
                # even when there are no evaluation results.
                if eval_results is not None:
                    save_flag = self._save_if_best(eval_results, epoch,
                                                   save_flag)

            self.epoch += 1

    def _save_if_best(self, eval_results, epoch, save_flag):
        """
        Save the model under ``model_best`` when the tracked metric improved.
        :param eval_results: dict of metrics returned by the evaluator
        :param epoch: epoch number recorded with the checkpoint and report
        :param save_flag: best metric value seen so far
        :return: the best metric value after this comparison
        """
        save_key = self.cfg.evaluator.save_key
        if save_key not in eval_results:
            warnings.warn(
                'Warning! Save_key is not in eval results! Only save model last!'
            )
            return save_flag

        metric = eval_results[save_key]
        if not metric > save_flag:
            return save_flag

        # ------save best model--------
        best_save_path = os.path.join(self.cfg.save_dir, 'model_best')
        mkdir(self.rank, best_save_path)
        save_model(self.rank, self.model,
                   os.path.join(best_save_path, 'model_best.pth'), epoch,
                   self._iter, self.optimizer)
        txt_path = os.path.join(best_save_path, "eval_results.txt")
        # Only ranks below 1 append to the text report.
        if self.rank < 1:
            with open(txt_path, "a") as f:
                f.write("Epoch:{}\n".format(epoch))
                for k, v in eval_results.items():
                    f.write("{}: {}\n".format(k, v))
        return metric