def _intermediate_stats_logging(self, outs, y, loss, step, epoch, N, len_loader, val_or_train):
     """Fold one batch into the running meters and log a progress line when due.

     Top-1 and top-3 precision are obtained from ``accuracy(outs, y, topk=(1, 3))``.
     The scalar loss and both precisions are pushed into ``self.losses``,
     ``self.top1`` and ``self.top3`` via ``update(value, N)``.

     A line is written to ``self.logger`` when ``step`` is greater than 1 and a
     multiple of ``self.print_freq``, or when ``step`` equals ``len_loader - 1``.

     Args:
         outs: Model outputs, forwarded to ``accuracy``.
         y: Targets, forwarded to ``accuracy``.
         loss: Loss for this batch; only ``loss.item()`` is used.
         step: Index of the current step within the loader.
         epoch: Zero-based epoch index (printed as ``epoch + 1``).
         N: Second argument given to every meter ``update`` call
             (presumably the batch size -- confirm against the caller).
         len_loader: Number of steps in the loader.
         val_or_train: Prefix prepended verbatim to the log message.
     """
     top1_acc, top3_acc = accuracy(outs, y, topk=(1, 3))
     for meter, value in ((self.losses, loss), (self.top1, top1_acc), (self.top3, top3_acc)):
         meter.update(value.item(), N)

     # Guard clause: nothing to print unless we hit the print interval or the last step.
     periodic = step > 1 and step % self.print_freq == 0
     if not (periodic or step == len_loader - 1):
         return

     # The prefix is concatenated after formatting, so it is never treated as a template.
     template = (": [{:3d}/{}] Step {:03d}/{:03d} Loss {:.3f} "
                 "Prec@(1,3) ({:.1%}, {:.1%})")
     stats = template.format(
         epoch + 1,
         self.cnt_epochs,
         step,
         len_loader - 1,
         self.losses.get_avg(),
         self.top1.get_avg(),
         self.top3.get_avg(),
     )
     self.logger.info(val_or_train + stats)
    def forward(self, outs, targets, latency, energy, losses_ce, losses_lat, losses_energy, N):
        """Combine the classification loss with latency and energy penalties.

        The returned value is::

            ce + self.alpha * log(latency ** self.beta)
               + self.gamma * log(energy ** self.delta)

        where ``ce = self.weight_criterion(outs, targets)``. Each of the three
        terms is replaced by zero when it is negative or NaN, and the same
        guard is applied once more to the final sum.

        Args:
            outs: Model outputs, forwarded to ``self.weight_criterion``.
            targets: Targets, forwarded to ``self.weight_criterion``.
            latency: Latency estimate. Must be a single-element tensor because
                the derived term is used directly in ``if`` conditions.
            energy: Energy estimate; same single-element requirement.
            losses_ce: Meter receiving ``update(ce_value, N)``.
            losses_lat: Meter receiving ``update(latency_term_value, N)``.
            losses_energy: Meter receiving ``update(energy_term_value, N)``.
            N: Second argument given to every meter ``update`` call
                (presumably the batch size -- confirm against the caller).

        Returns:
            A single-element tensor holding the combined loss.
        """
        ce = self.weight_criterion(outs, targets)
        if torch.isnan(ce) or ce < 0.0:
            # Fresh leaf tensor, detached from the model graph; requires_grad
            # is kept on (presumably so the caller can still call backward()).
            ce = torch.tensor(0.0, requires_grad=True)

        # Each penalty is computed once; the clamped value below feeds both
        # the meters and the loss.
        lat = torch.log(latency ** self.beta)
        energy = torch.log(energy ** self.delta)

        # Negative or NaN penalties contribute nothing to the loss.
        if lat < 0 or torch.isnan(lat):
            lat = torch.tensor(0.0)
        if energy < 0 or torch.isnan(energy):
            energy = torch.tensor(0.0)

        losses_ce.update(ce.item(), N)
        losses_lat.update(lat.item(), N)
        losses_energy.update(energy.item(), N)

        loss = ce + (self.alpha * lat) + (self.gamma * energy)
        if loss < 0 or torch.isnan(loss):
            loss = torch.tensor(0.0, requires_grad=True)

        return loss