def valid(self, loader, epoch):
    """Validate the model over one full pass of the valid loader.

    :param loader: valid data loader
    :type loader: DataLoader
    :param epoch: current epoch; kept for interface symmetry with ``train``
        (not used in the computation).
    :type epoch: int
    :return: metric results (averaged across workers under horovod)
    :rtype: list
    """
    metrics = Metrics(self.cfg.metric)
    # Validate the EMA weights when enabled; otherwise the raw model.
    model = self.model_ema.ema if self.use_ema else self.model
    model.eval()
    with torch.no_grad():
        # The batch index was never used, so iterate the loader directly
        # instead of the original `enumerate` with a discarded index.
        for input, target in loader:
            # When a prefetcher is used the batches are assumed to arrive
            # already on the GPU, so only move them explicitly otherwise.
            if self.cfg.cuda and not self.cfg.prefetcher:
                input, target = input.cuda(), target.cuda()
            logits = model(input)
            metrics(logits, target)
    prec = metrics.results
    if self.horovod:
        # Average each metric value across all horovod workers.
        prec = [
            self._metric_average(acc, self.cfg.metric.type)
            for acc in prec
        ]
    return prec
def _train(self, model):
    """Train a Parameter Sharing super model with train and valid data.

    The valid loader is consumed in lockstep with the train loader for the
    architecture updates; since it may be shorter than the train loader, it
    is restarted whenever it is exhausted.

    :param model: parameter sharing super model
    :type model: torch.nn.Module
    """
    metrics = Metrics(self.cfg.metric)
    model.train()
    valid_loader_iter = iter(self.valid_loader)
    for step, (train_input, train_target) in enumerate(self.train_loader):
        try:
            valid_input, valid_target = next(valid_loader_iter)
        except StopIteration:
            # Only an exhausted iterator warrants a restart; the original
            # `except Exception` would also have silently masked real
            # data-loading errors behind an endless restart.
            valid_loader_iter = iter(self.valid_loader)
            valid_input, valid_target = next(valid_loader_iter)
        train_input, train_target = train_input.to(
            self.device), train_target.to(self.device)
        valid_input, valid_target = valid_input.to(
            self.device), valid_target.to(self.device)
        # One architecture update then one weight update per step.
        self._train_arch_step(train_input, train_target,
                              valid_input, valid_target)
        train_logits = self._train_model_step(train_input, train_target)
        metrics(train_logits, train_target)
        top1 = metrics.results[0]
        if self._first_rank and step % self.cfg.print_step_interval == 0:
            # Lazy %-style args: the message is only rendered when the
            # record is actually emitted.
            logging.info("step [%d/%d], top1: %s",
                         step + 1, len(self.train_loader), top1)
def train(self, loader, epoch):
    """Train the model for one epoch.

    :param loader: train data loader
    :type loader: DataLoader
    :param epoch: current epoch; used to derive the global update count
        fed to the timm-style LR scheduler.
    :type epoch: int
    """
    metrics = Metrics(self.cfg.metric)
    self.model.train()
    loss_sum = 0.
    data_num = 0
    num_updates = epoch * len(loader)
    for step, (input, target) in enumerate(loader):
        # When a prefetcher is used the batches are assumed to arrive
        # already on the GPU, so only move them explicitly otherwise.
        if self.cfg.cuda and not self.cfg.prefetcher:
            input, target = input.cuda(), target.cuda()
        self.optimizer.zero_grad()
        logits = self.model(input)
        loss = self.loss(logits, target)
        if self.use_amp:
            # Apex AMP + horovod pattern: backward on the scaled loss,
            # synchronize gradients inside the scale context, then step
            # while skipping a second (redundant) synchronize.
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
                self.optimizer.synchronize()
            with self.optimizer.skip_synchronize():
                self.optimizer.step()
        else:
            loss.backward()
            self.optimizer.step()
        if self.use_ema:
            self.model_ema.update(self.model)
        metrics(logits, target)
        # Accumulate the sample-weighted loss for the running average.
        n = input.size(0)
        data_num += n
        loss_sum += loss.item() * n
        if self._first_rank and step % self.cfg.report_freq == 0:
            # Report-only values are computed inside the logging branch:
            # the original recomputed them on every iteration even though
            # they were consumed only here.
            loss_avg = loss_sum / data_num
            lrl = [
                param_group['lr']
                for param_group in self.optimizer.param_groups
            ]
            lr = sum(lrl) / len(lrl)
            prec = metrics.results
            logging.info(
                "step [%d/%d], top1 [%f], top5 [%f], loss avg [%f], lr [%f]",
                step, len(loader), prec[0], prec[1], loss_avg, lr)
        num_updates += 1
        if self.cfg.use_timm_lr_sched:
            # timm schedulers adjust the LR per optimizer update rather
            # than per epoch.
            self.lr_scheduler.step_update(num_updates=num_updates)
def _valid(self, model, loader):
    """Evaluate a Parameter Sharing model on the given data.

    :param model: network model
    :type model: torch.nn.Module
    :param loader: data loader
    :type loader: DataLoader
    :return: top1, top5
    :rtype: float, float
    """
    metrics = Metrics(self.cfg.metric)
    model.eval()
    with torch.no_grad():
        for batch_input, batch_target in loader:
            batch_input = batch_input.to(self.device)
            batch_target = batch_target.to(self.device)
            batch_logits = model(batch_input)
            metrics(batch_logits, batch_target)
    return metrics.results[0], metrics.results[1]
def _init_metrics(self, metrics=None): """Init metrics.""" if metrics is not None: return metrics else: return Metrics(self.cfg.metric)