Example #1
0
    def valid(self, loader, epoch):
        """Validate one epoch of the model.

        :param loader: valid data loader
        :type loader: DataLoader
        :param epoch: current epoch (not used in the loop; kept for
            interface symmetry with ``train``).
        :type epoch: int
        :return: metric results, averaged across workers under horovod.
        :rtype: list
        """
        metrics = Metrics(self.cfg.metric)
        # Evaluate EMA weights when enabled; otherwise the live model.
        model = self.model_ema.ema if self.use_ema else self.model
        model.eval()
        with torch.no_grad():
            # Iterate the loader directly; the enumerate index was unused.
            for batch, target in loader:
                if self.cfg.cuda and not self.cfg.prefetcher:
                    # The prefetcher already places batches on GPU;
                    # only copy manually when it is disabled.
                    batch, target = batch.cuda(), target.cuda()
                logits = model(batch)
                metrics(logits, target)
        prec = metrics.results
        if self.horovod:
            # Average every metric value across horovod workers.
            prec = [
                self._metric_average(acc, self.cfg.metric.type) for acc in prec
            ]
        return prec
Example #2
0
    def _train(self, model):
        """Train Parameter Sharing model with train and valid data.

        :param model: parameter sharing super model
        :type model: torch.nn.Module
        """
        metrics = Metrics(self.cfg.metric)
        model.train()
        valid_loader_iter = iter(self.valid_loader)
        for step, (train_input, train_target) in enumerate(self.train_loader):
            try:
                valid_input, valid_target = next(valid_loader_iter)
            except StopIteration:
                # The valid loader may be shorter than the train loader:
                # restart it instead of swallowing arbitrary errors.
                valid_loader_iter = iter(self.valid_loader)
                valid_input, valid_target = next(valid_loader_iter)
            train_input = train_input.to(self.device)
            train_target = train_target.to(self.device)
            valid_input = valid_input.to(self.device)
            valid_target = valid_target.to(self.device)
            # Architecture parameters are updated on the valid batch,
            # model weights on the train batch.
            self._train_arch_step(train_input, train_target, valid_input,
                                  valid_target)
            train_logits = self._train_model_step(train_input, train_target)
            metrics(train_logits, train_target)
            top1 = metrics.results[0]
            if self._first_rank and step % self.cfg.print_step_interval == 0:
                logging.info("step [{}/{}], top1: {}".format(
                    step + 1, len(self.train_loader), top1))
Example #3
0
    def train(self, loader, epoch):
        """Train one epoch of the model.

        :param loader: train data loader
        :type loader: DataLoader
        :param epoch: current epoch; seeds the global update counter used
            by the per-step (timm-style) LR scheduler.
        :type epoch: int
        """
        metrics = Metrics(self.cfg.metric)
        self.model.train()
        loss_sum = 0.
        data_num = 0
        # Global update counter, continued from previous epochs.
        num_updates = epoch * len(loader)
        for step, (inputs, target) in enumerate(loader):
            if self.cfg.cuda and not self.cfg.prefetcher:
                # The prefetcher already moves batches to GPU; only copy
                # manually when it is disabled.
                inputs, target = inputs.cuda(), target.cuda()
            self.optimizer.zero_grad()
            logits = self.model(inputs)
            loss = self.loss(logits, target)
            if self.use_amp:
                # Mixed precision path: scale the loss before backward and
                # make sure gradients are synchronized before stepping.
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
                    self.optimizer.synchronize()
                with self.optimizer.skip_synchronize():
                    self.optimizer.step()
            else:
                loss.backward()
                self.optimizer.step()
            if self.use_ema:
                # Maintain an exponential moving average of the weights.
                self.model_ema.update(self.model)
            metrics(logits, target)
            batch_size = inputs.size(0)
            data_num += batch_size
            # Weight the running loss by batch size so loss_avg is a true
            # per-sample average even with a ragged last batch.
            loss_sum += loss.item() * batch_size
            loss_avg = loss_sum / data_num
            lr_values = [
                param_group['lr']
                for param_group in self.optimizer.param_groups
            ]
            lr = sum(lr_values) / len(lr_values)
            prec = metrics.results
            if self._first_rank and step % self.cfg.report_freq == 0:
                logging.info(
                    "step [%d/%d], top1 [%f], top5 [%f], loss avg [%f], lr [%f]",
                    step, len(loader), prec[0], prec[1], loss_avg, lr)
            num_updates += 1
            if self.cfg.use_timm_lr_sched:
                self.lr_scheduler.step_update(num_updates=num_updates)
Example #4
0
    def _valid(self, model, loader):
        """Validate Parameter Sharing model with data.

        :param model: network model
        :type model: torch.nn.Module
        :param loader: data loader
        :type loader: DataLoader
        :return: top1, top5
        :rtype: float, float
        """
        metrics = Metrics(self.cfg.metric)
        model.eval()
        with torch.no_grad():
            # Iterate the loader directly; the enumerate index was unused.
            for batch, target in loader:
                batch, target = batch.to(self.device), target.to(self.device)
                logits = model(batch)
                metrics(logits, target)
        # metrics.results is ordered [top1, top5].
        top1, top5 = metrics.results[0], metrics.results[1]
        return top1, top5
Example #5
0
 def _init_metrics(self, metrics=None):
     """Init metrics."""
     if metrics is not None:
         return metrics
     else:
         return Metrics(self.cfg.metric)