Example No. 3

class TrainerArch:
    def __init__(self, criterion, optimizer, scheduler, logger, writer):
        self.top1 = AverageMeter()
        self.top3 = AverageMeter()
        self.losses = AverageMeter()

        self.logger = logger
        self.writer = writer

        self.optimizer = optimizer
        self.criterion = criterion
        self.scheduler = scheduler

        self.path_to_save_model = CONFIG_ARCH['train_settings'][
            'path_to_save_model']
        self.cnt_epochs = CONFIG_ARCH['train_settings']['cnt_epochs']
        self.print_freq = CONFIG_ARCH['train_settings']['print_freq']
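        # The keys above are all that this trainer reads from CONFIG_ARCH;
        # an illustrative (assumed) shape:
        # CONFIG_ARCH = {'train_settings': {'path_to_save_model': '...',
        #                                   'cnt_epochs': 100,
        #                                   'print_freq': 50}}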

    def train_loop(self, train_loader, valid_loader, model):
        best_top1 = 0.0

        for epoch in range(self.cnt_epochs):

            self.writer.add_scalar('learning_rate',
                                   self.optimizer.param_groups[0]['lr'], epoch)

            #if epoch and epoch % self.lr_decay_period == 0:
            #    self.optimizer.param_groups[0]['lr'] *= self.lr_decay

            # training
            self._train(train_loader, model, epoch)
            # validation
            top1_avg = self._validate(valid_loader, model, epoch)

            if best_top1 < top1_avg:
                best_top1 = top1_avg
                self.logger.info("Best top1 accuracy by now. Save model")
                save(model, self.path_to_save_model)
            self.scheduler.step()

    def _train(self, loader, model, epoch):
        start_time = time.time()
        model = model.train()

        for step, (X, y) in enumerate(loader):
            X, y = X.cuda(non_blocking=True), y.cuda(non_blocking=True)
            #X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
            N = X.shape[0]

            self.optimizer.zero_grad()
            outs = model(X)
            loss = self.criterion(outs, y)
            loss.backward()
            self.optimizer.step()

            self._intermediate_stats_logging(outs,
                                             y,
                                             loss,
                                             step,
                                             epoch,
                                             N,
                                             len_loader=len(loader),
                                             val_or_train="Train")

        self._epoch_stats_logging(start_time=start_time,
                                  epoch=epoch,
                                  val_or_train='train')
        for avg in [self.top1, self.top3, self.losses]:
            avg.reset()

    def _validate(self, loader, model, epoch):
        model.eval()
        start_time = time.time()

        with torch.no_grad():
            for step, (X, y) in enumerate(loader):
                X, y = X.cuda(), y.cuda()
                N = X.shape[0]

                outs = model(X)
                loss = self.criterion(outs, y)

                self._intermediate_stats_logging(outs,
                                                 y,
                                                 loss,
                                                 step,
                                                 epoch,
                                                 N,
                                                 len_loader=len(loader),
                                                 val_or_train="Valid")

        top1_avg = self.top1.get_avg()
        self._epoch_stats_logging(start_time=start_time,
                                  epoch=epoch,
                                  val_or_train='val')
        for avg in [self.top1, self.top3, self.losses]:
            avg.reset()
        return top1_avg

    def _epoch_stats_logging(self, start_time, epoch, val_or_train):
        self.writer.add_scalar('train_vs_val/' + val_or_train + '_loss',
                               self.losses.get_avg(), epoch)
        self.writer.add_scalar('train_vs_val/' + val_or_train + '_top1',
                               self.top1.get_avg(), epoch)
        self.writer.add_scalar('train_vs_val/' + val_or_train + '_top3',
                               self.top3.get_avg(), epoch)

        top1_avg = self.top1.get_avg()
        self.logger.info(val_or_train +
                         ": [{:3d}/{}] Final Prec@1 {:.4%} Time {:.2f}".format(
                             epoch + 1, self.cnt_epochs, top1_avg,
                             time.time() - start_time))

    def _intermediate_stats_logging(self, outs, y, loss, step, epoch, N,
                                    len_loader, val_or_train):
        prec1, prec3 = accuracy(outs, y, topk=(1, 3))
        self.losses.update(loss.item(), N)
        self.top1.update(prec1.item(), N)
        self.top3.update(prec3.item(), N)

        if (step > 1
                and step % self.print_freq == 0) or step == len_loader - 1:
            self.logger.info(val_or_train +
                             ": [{:3d}/{}] Step {:03d}/{:03d} Loss {:.3f} "
                             "Prec@(1,3) ({:.1%}, {:.1%})".format(
                                 epoch + 1, self.cnt_epochs, step, len_loader -
                                 1, self.losses.get_avg(), self.top1.get_avg(),
                                 self.top3.get_avg()))
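
The examples on this page rely on two helpers that are not shown here: AverageMeter and accuracy. The following is a minimal sketch consistent with the calls made above (update(val, n), get_avg(), reset(), and accuracy(outs, y, topk=...) returning fractions in [0, 1]); the repository's actual implementations may differ:

class AverageMeter:
    """Tracks a running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    def get_avg(self):
        return self.sum / self.count if self.count else 0.0

    def __str__(self):
        # Example No. 5 parses str(meter) back into a float,
        # so keep the representation bare
        return str(self.get_avg())


def accuracy(outs, y, topk=(1,)):
    """Top-k accuracy as fractions in [0, 1], one value per k."""
    maxk = max(topk)
    _, pred = outs.topk(maxk, dim=1)      # (N, maxk) predicted class ids
    correct = pred.eq(y.view(-1, 1))      # broadcast comparison with targets
    return [correct[:, :k].any(dim=1).float().mean() for k in topk]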
Example No. 4

class TrainerSupernet:
    def __init__(self, criterion, w_optimizer, theta_optimizer, w_scheduler,
                 logger, writer, high):
        self.top1 = AverageMeter()
        self.top3 = AverageMeter()
        self.losses = AverageMeter()
        self.losses_lat = AverageMeter()
        self.losses_ce = AverageMeter()

        self.logger = logger
        self.writer = writer

        self.criterion = criterion
        self.w_optimizer = w_optimizer
        self.theta_optimizer = theta_optimizer
        self.w_scheduler = w_scheduler
        self.high = high

        self.temperature = CONFIG_SUPERNET['train_settings'][
            'init_temperature']
        self.exp_anneal_rate = CONFIG_SUPERNET['train_settings'][
            'exp_anneal_rate']  # apply it every epoch
        self.cnt_epochs = CONFIG_SUPERNET['train_settings']['cnt_epochs']
        self.train_thetas_from_the_epoch = CONFIG_SUPERNET['train_settings'][
            'train_thetas_from_the_epoch']
        self.print_freq = CONFIG_SUPERNET['train_settings']['print_freq']
        if high:
            self.path_to_save_model = CONFIG_SUPERNET['train_settings'][
                'path_to_save_model_high']
        else:
            self.path_to_save_model = CONFIG_SUPERNET['train_settings'][
                'path_to_save_model']

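    # train_loop follows an FBNet-style two-phase schedule: weights only for
    # the first `train_thetas_from_the_epoch` epochs, then alternating weight
    # and theta (architecture) steps, with the sampling temperature annealed
    # once per epoch in the second phase.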
    def train_loop(self, train_w_loader, train_thetas_loader, test_loader,
                   model):

        best_top1 = 0.0

        # firstly, train weights only
        for epoch in range(self.train_thetas_from_the_epoch):
            self.writer.add_scalar('learning_rate/weights',
                                   self.w_optimizer.param_groups[0]['lr'],
                                   epoch)
            self.logger.info("First phase: training weights only for epoch %d"
                             % epoch)
            if epoch == 5:
                self.w_optimizer.param_groups[0]['lr'] *= 0.1
            self._training_step(model,
                                train_w_loader,
                                self.w_optimizer,
                                epoch,
                                info_for_logger="_w_step_")
            self.w_scheduler.step()
            top1_avg = self._validate(model, test_loader, epoch)
            if best_top1 < top1_avg:
                best_top1 = top1_avg
                self.logger.info("Best top1 acc by now. Save model")
                save(model, self.path_to_save_model)
                print("Best model saved!")

        for epoch in range(self.train_thetas_from_the_epoch, self.cnt_epochs):
            self.writer.add_scalar('learning_rate/weights',
                                   self.w_optimizer.param_groups[0]['lr'],
                                   epoch)
            self.writer.add_scalar('learning_rate/theta',
                                   self.theta_optimizer.param_groups[0]['lr'],
                                   epoch)

            self.logger.info("Start to train weights for epoch %d" % (epoch))
            if epoch == 10:
                self.w_optimizer.param_groups[0]['lr'] *= 0.1
                self.theta_optimizer.param_groups[0]['lr'] *= 0.1
            if epoch == 20:
                self.w_optimizer.param_groups[0]['lr'] *= 0.1
                self.theta_optimizer.param_groups[0]['lr'] *= 0.1

            self._training_step(model,
                                train_w_loader,
                                self.w_optimizer,
                                epoch,
                                info_for_logger="_w_step_")
            self.w_scheduler.step()

            self.logger.info("Start to train theta for epoch %d" % (epoch))
            self._training_step(model,
                                train_thetas_loader,
                                self.theta_optimizer,
                                epoch,
                                info_for_logger="_theta_step_")

            top1_avg = self._validate(model, test_loader, epoch)
            if best_top1 < top1_avg:
                best_top1 = top1_avg
                self.logger.info("Best top1 acc by now. Save model")
                save(model, self.path_to_save_model)
                print("Best model saved!")

            self.temperature = self.temperature * self.exp_anneal_rate
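            # exponential schedule: after t epochs of this phase the
            # temperature equals init_temperature * exp_anneal_rate ** t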

    def _training_step(self,
                       model,
                       loader,
                       optimizer,
                       epoch,
                       info_for_logger=""):
        model = model.train()
        start_time = time.time()

        for step, (X, y) in enumerate(loader):
            X, y = X.cuda(non_blocking=True), y.cuda(non_blocking=True)
            # X.to(device, non_blocking=True), y.to(device, non_blocking=True)
            N = X.shape[0]

            optimizer.zero_grad()
            # torch.autograd.Variable is deprecated; a plain tensor with
            # requires_grad=True works the same here
            latency_to_accumulate = torch.tensor([[0.0]], device='cuda',
                                                 requires_grad=True)
            outs, latency_to_accumulate = model(X, self.temperature,
                                                latency_to_accumulate)
            loss = self.criterion(outs, y, latency_to_accumulate,
                                  self.losses_ce, self.losses_lat, N)
            loss.backward()
            optimizer.step()

            self._intermediate_stats_logging(outs,
                                             y,
                                             loss,
                                             step,
                                             epoch,
                                             N,
                                             len_loader=len(loader),
                                             val_or_train="Train")

        self._epoch_stats_logging(start_time=start_time,
                                  epoch=epoch,
                                  info_for_logger=info_for_logger,
                                  val_or_train='train')
        for avg in [self.top1, self.top3, self.losses]:
            avg.reset()

    def _validate(self, model, loader, epoch):
        model.eval()
        start_time = time.time()

        with torch.no_grad():
            for step, (X, y) in enumerate(loader):
                X, y = X.cuda(), y.cuda()
                N = X.shape[0]

                latency_to_accumulate = torch.Tensor([[0.0]]).cuda()
                outs, latency_to_accumulate = model(X, self.temperature,
                                                    latency_to_accumulate)
                loss = self.criterion(outs, y, latency_to_accumulate,
                                      self.losses_ce, self.losses_lat, N)

                self._intermediate_stats_logging(outs,
                                                 y,
                                                 loss,
                                                 step,
                                                 epoch,
                                                 N,
                                                 len_loader=len(loader),
                                                 val_or_train="Valid")

        top1_avg = self.top1.get_avg()
        self._epoch_stats_logging(start_time=start_time,
                                  epoch=epoch,
                                  val_or_train='val')
        for avg in [self.top1, self.top3, self.losses]:
            avg.reset()
        return top1_avg

    def _epoch_stats_logging(self,
                             start_time,
                             epoch,
                             val_or_train,
                             info_for_logger=''):
        self.writer.add_scalar(
            'train_vs_val/' + val_or_train + '_loss' + info_for_logger,
            self.losses.get_avg(), epoch)
        self.writer.add_scalar(
            'train_vs_val/' + val_or_train + '_top1' + info_for_logger,
            self.top1.get_avg(), epoch)
        self.writer.add_scalar(
            'train_vs_val/' + val_or_train + '_top3' + info_for_logger,
            self.top3.get_avg(), epoch)
        self.writer.add_scalar(
            'train_vs_val/' + val_or_train + '_losses_lat' + info_for_logger,
            self.losses_lat.get_avg(), epoch)
        self.writer.add_scalar(
            'train_vs_val/' + val_or_train + '_losses_ce' + info_for_logger,
            self.losses_ce.get_avg(), epoch)

        top1_avg = self.top1.get_avg()
        self.logger.info(info_for_logger + val_or_train +
                         ": [{:3d}/{}] Final Prec@1 {:.4%} Time {:.2f}".format(
                             epoch + 1, self.cnt_epochs, top1_avg,
                             time.time() - start_time))

    def _intermediate_stats_logging(self, outs, y, loss, step, epoch, N,
                                    len_loader, val_or_train):
        prec1, prec3 = accuracy(outs, y, topk=(1, 3))  # top-3 to match the top3 meter and Prec@(1,3) log line
        self.losses.update(loss.item(), N)
        self.top1.update(prec1.item(), N)
        self.top3.update(prec3.item(), N)

        if (step > 1
                and step % self.print_freq == 0) or step == len_loader - 1:
            self.logger.info(
                val_or_train + ": [{:3d}/{}] Step {:03d}/{:03d} Loss {:.3f} "
                "Prec@(1,3) ({:.1%}, {:.1%}), ce_loss {:.3f}, lat_loss {:.3f}".
                format(epoch + 1, self.cnt_epochs, step, len_loader -
                       1, self.losses.get_avg(), self.top1.get_avg(),
                       self.top3.get_avg(), self.losses_ce.get_avg(),
                       self.losses_lat.get_avg()))
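
A wiring sketch for this trainer. Everything except TrainerSupernet itself is a placeholder (FBNetSupernet, SupernetLoss, logger, writer, and the loaders are assumed to come from the surrounding repository), so treat this as illustration rather than the repository's actual entry point:

import torch

model = FBNetSupernet().cuda()  # assumed supernet whose forward takes
                                # (X, temperature, latency_to_accumulate)
w_params = [p for n, p in model.named_parameters() if 'theta' not in n]
theta_params = [p for n, p in model.named_parameters() if 'theta' in n]

w_optimizer = torch.optim.SGD(w_params, lr=0.1, momentum=0.9)
theta_optimizer = torch.optim.Adam(theta_params, lr=0.01)
w_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    w_optimizer, T_max=CONFIG_SUPERNET['train_settings']['cnt_epochs'])

trainer = TrainerSupernet(criterion=SupernetLoss(),  # assumed CE + latency loss
                          w_optimizer=w_optimizer,
                          theta_optimizer=theta_optimizer,
                          w_scheduler=w_scheduler,
                          logger=logger,
                          writer=writer,
                          high=False)
trainer.train_loop(train_w_loader, train_thetas_loader, test_loader, model)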
Example No. 5

class TrainerSupernet:
    def __init__(self, criterion, w_optimizer, theta_optimizer, w_scheduler,
                 logger, writer, experiment):
        self.top1 = AverageMeter()
        self.top3 = AverageMeter()
        self.losses = AverageMeter()
        self.losses_lat = AverageMeter()
        self.losses_ce = AverageMeter()
        self.losses_energy = AverageMeter()  #energy add
        self.logger = logger
        self.writer = writer

        self.criterion = criterion
        self.w_optimizer = w_optimizer
        self.theta_optimizer = theta_optimizer
        self.w_scheduler = w_scheduler
        self.experiment = experiment

        self.temperature = CONFIG_SUPERNET['train_settings'][
            'init_temperature']
        self.exp_anneal_rate = CONFIG_SUPERNET['train_settings'][
            'exp_anneal_rate']  # apply it every epoch
        self.cnt_epochs = CONFIG_SUPERNET['train_settings']['cnt_epochs']
        self.train_thetas_from_the_epoch = CONFIG_SUPERNET['train_settings'][
            'train_thetas_from_the_epoch']
        self.print_freq = CONFIG_SUPERNET['train_settings']['print_freq']
        self.path_to_save_model = CONFIG_SUPERNET['train_settings'][
            'path_to_save_model']

    def train_loop(self, train_w_loader, train_thetas_loader, test_loader,
                   model):
        global n
        best_top1 = 0.0
        best_lat = 10000000
        best_energy = 10000000
        n = 1
        # firstly, train weights only
        for epoch in range(self.train_thetas_from_the_epoch):
            self.writer.add_scalar('learning_rate/weights',
                                   self.w_optimizer.param_groups[0]['lr'],
                                   epoch)

            self.logger.info("Firstly, start to train weights for epoch %d" %
                             (epoch))
            self._training_step(model,
                                train_w_loader,
                                self.w_optimizer,
                                epoch,
                                info_for_logger="_w_step_")
            self.w_scheduler.step()

        for epoch in range(self.train_thetas_from_the_epoch, self.cnt_epochs):
            self.writer.add_scalar('learning_rate/weights',
                                   self.w_optimizer.param_groups[0]['lr'],
                                   epoch)
            self.writer.add_scalar('learning_rate/theta',
                                   self.theta_optimizer.param_groups[0]['lr'],
                                   epoch)

            self.logger.info("Start to train weights for epoch %d" % (epoch))
            self._training_step(model,
                                train_w_loader,
                                self.w_optimizer,
                                epoch,
                                info_for_logger="_w_step_")
            self.w_scheduler.step()

            self.logger.info("Start to train theta for epoch %d" % (epoch))
            self._training_step(model,
                                train_thetas_loader,
                                self.theta_optimizer,
                                epoch,
                                info_for_logger="_theta_step_")

            top1_avg, top3_avg, lat_avg, energy_avg = self._validate(
                model, test_loader, epoch)
            #if best_top1 < top1_avg and lat_avg < best_lat:
            #if best_top1 < top1_avg: #original

            # save whenever top-1 accuracy improves; best latency/energy are
            # tracked alongside for the log messages
            if best_top1 < top1_avg:
                if best_top1 < top1_avg:
                    best_top1 = top1_avg
                    print("Best Acc!!")
                if lat_avg < best_lat:
                    best_lat = lat_avg
                    print("Best Speed!!")
                if energy_avg < best_energy:
                    best_energy = energy_avg
                    print("Best Energy!!")
                self.logger.info("Best top1 acc by now. Save model")
                #print("Over Acc: 0.70")
                #print("Model Number = " + str(n))
                save(model, self.path_to_save_model + str(n) + '.pth')
                #n += 1
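                # note: with `n += 1` left commented out, n stays 1 and every
                # save overwrites the same checkpoint file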
            '''
            if (top1_avg >= 0.75) or (top1_avg >= 0.75  and lat_avg < best_lat) or (top1_avg >= 0.75  and energy_avg < best_energy) :
                if lat_avg < best_lat:
                    best_lat = lat_avg
                    print("Best Latency!!")
                if energy_avg < best_energy:
                     best_energy = energy_avg
                     print("Best Energy!!")
                self.logger.info("Best top1 acc by now. Save model")
                print("Over Acc: 0.75")
                #print("Model Number = " + str(n))
                save(model, self.path_to_save_model + str(n) + '.pth')
                #n+=1
            '''
            self.temperature = self.temperature * self.exp_anneal_rate

    def _training_step(self,
                       model,
                       loader,
                       optimizer,
                       epoch,
                       info_for_logger=""):
        model = model.train()
        start_time = time.time()
        global i
        global j
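        # i and j (and k in _validate) are assumed to be module-level step
        # counters initialized elsewhere, e.g. i = j = k = 0 at import time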
        for step, (X, y) in enumerate(loader):
            #X, y = X.cuda(non_blocking=True), y.cuda(non_blocking=True)
            # X.to(device, non_blocking=True), y.to(device, non_blocking=True)
            X, y = X.to(device, non_blocking=True), y.to(device,
                                                         non_blocking=True)
            N = X.shape[0]
            optimizer.zero_grad()
            latency_to_accumulate = torch.tensor([[0.0]], device=device,
                                                 requires_grad=True)
            energy_to_accumulate = torch.tensor([[0.0]], device=device,
                                                requires_grad=True)  #energy add
            outs, latency_to_accumulate, energy_to_accumulate = model(
                X, self.temperature, latency_to_accumulate,
                energy_to_accumulate)  #energy add
            loss = self.criterion(outs, y, latency_to_accumulate,
                                  energy_to_accumulate, self.losses_ce,
                                  self.losses_lat, self.losses_energy,
                                  N)  #energy add
            loss.backward()
            optimizer.step()

            self._intermediate_stats_logging(outs,
                                             y,
                                             loss,
                                             step,
                                             epoch,
                                             N,
                                             len_loader=len(loader),
                                             val_or_train="Train")
            if info_for_logger == "_w_step_":
                i += 1
            elif info_for_logger == "_theta_step_":
                j += 1
            list1 = ["acc1", "acc3", "ce_loss", "lat_loss", "energy_loss"]
            cnt = 0
            for word in [
                    self.top1, self.top3, self.losses_ce, self.losses_lat,
                    self.losses_energy
            ]:
                word = str(word)
                word = word.replace(' ', '')
                word = word.replace(':', '')
                if info_for_logger == "_w_step_":
                    self.experiment.log_metric("w_train_" + list1[cnt],
                                               float(word),
                                               step=i)
                elif info_for_logger == "_theta_step_":
                    self.experiment.log_metric("theta_train_" + list1[cnt],
                                               float(word),
                                               step=j)
                cnt += 1
            # loss is a single-element tensor; item() extracts the scalar
            total_loss = loss.item()
            if info_for_logger == "_w_step_":
                self.experiment.log_metric("w_train_totalLoss", total_loss,
                                           step=i)
            elif info_for_logger == "_theta_step_":
                self.experiment.log_metric("theta_train_totalLoss", total_loss,
                                           step=j)

        self._epoch_stats_logging(start_time=start_time,
                                  epoch=epoch,
                                  info_for_logger=info_for_logger,
                                  val_or_train='train')
        for avg in [
                self.top1, self.top3, self.losses, self.losses_ce,
                self.losses_lat, self.losses_energy
        ]:
            avg.reset()

    def _validate(self, model, loader, epoch):
        model.eval()
        start_time = time.time()
        global k
        with torch.no_grad():
            for step, (X, y) in enumerate(loader):
                #X, y = X.cuda(), y.cuda()
                X, y = X.to(device), y.to(device)
                N = X.shape[0]
                latency_to_accumulate = torch.Tensor([[0.0]]).to(device)
                energy_to_accumulate = torch.Tensor([[0.0]]).to(device)  #energy add

                outs, latency_to_accumulate, energy_to_accumulate = model(
                    X, self.temperature, latency_to_accumulate,
                    energy_to_accumulate)  #energy add
                loss = self.criterion(outs, y, latency_to_accumulate,
                                      energy_to_accumulate, self.losses_ce,
                                      self.losses_lat, self.losses_energy,
                                      N)  #energy add

                self._intermediate_stats_logging(outs,
                                                 y,
                                                 loss,
                                                 step,
                                                 epoch,
                                                 N,
                                                 len_loader=len(loader),
                                                 val_or_train="Valid")
                k += 1
                list1 = ["acc1", "acc3", "ce_loss", "lat_loss", "energy_loss"]
                cnt = 0
                for word in [
                        self.top1, self.top3, self.losses_ce, self.losses_lat,
                        self.losses_energy
                ]:
                    word = str(word)
                    word = word.replace(' ', '')
                    word = word.replace(':', '')
                    self.experiment.log_metric("val_" + list1[cnt],
                                               float(word),
                                               step=k)
                    cnt += 1
                total_loss = loss.to('cpu').detach().numpy().copy()
                total_loss = str(total_loss)
                total_loss = total_loss.replace('[[', '')
                total_loss = total_loss.replace(']]', '')

                self.experiment.log_metric("val_totalLoss",
                                           float(total_loss),
                                           step=k)
        # capture all averages before the meters are reset below; the top3,
        # latency, and energy averages were previously read after the reset
        top1_avg = self.top1.get_avg()
        top3_avg = self.top3.get_avg()
        lat_avg = self.losses_lat.get_avg()
        energy_avg = self.losses_energy.get_avg()
        self._epoch_stats_logging(start_time=start_time,
                                  epoch=epoch,
                                  val_or_train='val')
        for avg in [
                self.top1, self.top3, self.losses, self.losses_ce,
                self.losses_lat, self.losses_energy
        ]:
            avg.reset()
        return top1_avg, top3_avg, lat_avg, energy_avg  # lat, energy added

    def _epoch_stats_logging(self,
                             start_time,
                             epoch,
                             val_or_train,
                             info_for_logger=''):
        self.writer.add_scalar(
            'train_vs_val/' + val_or_train + '_loss' + info_for_logger,
            self.losses.get_avg(), epoch)
        self.writer.add_scalar(
            'train_vs_val/' + val_or_train + '_top1' + info_for_logger,
            self.top1.get_avg(), epoch)
        self.writer.add_scalar(
            'train_vs_val/' + val_or_train + '_top3' + info_for_logger,
            self.top3.get_avg(), epoch)
        self.writer.add_scalar(
            'train_vs_val/' + val_or_train + '_losses_lat' + info_for_logger,
            self.losses_lat.get_avg(), epoch)
        self.writer.add_scalar(
            'train_vs_val/' + val_or_train + '_losses_ce' + info_for_logger,
            self.losses_ce.get_avg(), epoch)
        self.writer.add_scalar('train_vs_val/' + val_or_train +
                               '_losses_energy' + info_for_logger,
                               self.losses_energy.get_avg(),
                               epoch)  #energy add

        top1_avg = self.top1.get_avg()
        self.logger.info(info_for_logger + val_or_train +
                         ": [{:3d}/{}] Final Prec@1 {:.4%} Time {:.2f}".format(
                             epoch + 1, self.cnt_epochs, top1_avg,
                             time.time() - start_time))

    def _intermediate_stats_logging(self, outs, y, loss, step, epoch, N,
                                    len_loader, val_or_train):
        prec1, prec3 = accuracy(outs, y, topk=(1, 3))  # top-3 to match the top3 meter and Prec@(1,3) log line
        self.losses.update(loss.item(), N)
        self.top1.update(prec1.item(), N)
        self.top3.update(prec3.item(), N)

        if (step > 1
                and step % self.print_freq == 0) or step == len_loader - 1:
            self.logger.info(
                val_or_train + ": [{:3d}/{}] Step {:03d}/{:03d} Loss {:.3f} "
                "Prec@(1,3) ({:.1%}, {:.1%}), ce_loss {:.3f}, lat_loss {:.3f}, energy_loss {:.3f}"
                .format(epoch + 1, self.cnt_epochs, step, len_loader - 1,
                        self.losses.get_avg(), self.top1.get_avg(),
                        self.top3.get_avg(), self.losses_ce.get_avg(),
                        self.losses_lat.get_avg(),
                        self.losses_energy.get_avg()))  #energy add
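
The criterion in Example No. 5 is called as criterion(outs, y, latency, energy, ce_meter, lat_meter, energy_meter, N). Below is a minimal sketch of a loss with that signature, assuming a weighted sum of cross-entropy, latency, and energy terms; alpha and beta are assumed hyperparameters, not values taken from the source:

import torch.nn as nn


class SupernetLossWithEnergy(nn.Module):
    """Hypothetical multi-objective loss matching the call signature above."""

    def __init__(self, alpha=0.1, beta=0.1):
        super().__init__()
        self.alpha = alpha  # latency weight (assumed)
        self.beta = beta    # energy weight (assumed)
        self.ce = nn.CrossEntropyLoss()

    def forward(self, outs, targets, latency, energy,
                ce_meter, lat_meter, energy_meter, N):
        ce_loss = self.ce(outs, targets)
        lat_loss = self.alpha * latency.mean()
        energy_loss = self.beta * energy.mean()

        # update the meters the trainer reads back for logging
        ce_meter.update(ce_loss.item(), N)
        lat_meter.update(lat_loss.item(), N)
        energy_meter.update(energy_loss.item(), N)

        return ce_loss + lat_loss + energy_loss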