def test(self, ):
    """Predict every validation image, save it, and print mean PSNR/SSIM.

    Iterates ``self.val_loader`` (dict batches with ``'image'``, ``'target'``
    and ``'name'`` entries) and runs ``self.model`` without gradients. Each
    prediction and its target are converted to uint8 images in [0, 255]. The
    prediction is written to ``<self.args.checkpoint>/<name>``, and PSNR and
    SSIM between prediction and target are accumulated. Results are printed;
    nothing is returned.
    """
    # switch to evaluate mode
    self.model.eval()
    ssimes = AverageMeter()
    psnres = AverageMeter()
    with torch.no_grad():
        for batches in self.val_loader:
            inputs = batches['image'].to(self.device)
            target = batches['target'].to(self.device)
            outputs = self.model(inputs)
            # Select the prediction according to the architecture's return
            # type: a bare tensor, a nested list (first entry of the first
            # entry), or a sequence whose first entry is taken as the output.
            if isinstance(outputs, torch.Tensor):
                output = outputs
            elif isinstance(outputs[0], list):
                output = outputs[0][0]
            else:
                output = outputs[0]
            # Score and save every image in the batch. Previously only index 0
            # was used (weighted by the batch size), silently dropping the rest
            # when batch size > 1.
            for k in range(output.size(0)):
                # recover the image to 255
                out_img = im_to_numpy(
                    torch.clamp(output[k] * 255, min=0.0, max=255.0)).astype(np.uint8)
                tgt_img = im_to_numpy(
                    torch.clamp(target[k] * 255, min=0.0, max=255.0)).astype(np.uint8)
                skimage.io.imsave(
                    '%s/%s' % (self.args.checkpoint, batches['name'][k]), out_img)
                psnres.update(compare_psnr(tgt_img, out_img), 1)
                ssimes.update(compare_ssim(tgt_img, out_img, multichannel=True), 1)
    print("%s:PSNR:%s,SSIM:%s" % (self.args.checkpoint, psnres.avg, ssimes.avg))
    print("DONE.\n")
def train(self, epoch):
    """Run one adversarial training epoch with multi-scale attention supervision.

    Each batch is ``(inputs, target)``, where ``inputs`` unpacks to
    ``(input_image, mask, m2s, m4s, m8s)``. Every iteration does two steps:

    1. Generator step: the model receives ``cat(input_image, mask)`` and
       returns ``(fake_input, mask8s, mask4s, mask2s)``. The loss is
       ``GAN + 100 * L1 + 90 * (attn8s + attn4s + attn2s)``.
    2. Discriminator step: the loss is ``0.5 * (real + fake)``, computed on the
       detached generator output.

    Epoch-averaged losses are written to ``self.writer`` at step ``epoch``.

    Args:
        epoch: epoch index. It is combined with the batch index to form the
            TensorBoard step for image summaries.
    """
    self.current_epoch = epoch
    if self.args.freeze and epoch > 10 and not getattr(self, '_rasc_frozen', False):
        # Freeze the RASC weighting and rebuild the generator optimizer over
        # the parameters that still require gradients. This is done once only:
        # rebuilding every epoch threw away Adam's moment estimates each time.
        self.model.freeze_weighting_of_rasc()
        self.optimizer_G = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.args.lr, betas=(0.5, 0.999),
            weight_decay=self.args.weight_decay)
        self._rasc_frozen = True

    batch_time = AverageMeter()
    data_time = AverageMeter()
    LoggerLossG = AverageMeter()
    LoggerLossGGAN = AverageMeter()
    LoggerLossGL1 = AverageMeter()
    LoggerLossD = AverageMeter()
    LoggerLossDreal = AverageMeter()
    LoggerLossDfake = AverageMeter()
    lossMask8s = AverageMeter()
    lossMask4s = AverageMeter()
    lossMask2s = AverageMeter()

    # switch to train mode
    self.model.train()
    self.discriminator.train()
    end = time.time()
    bar = Bar('Processing {} '.format(self.args.arch), max=len(self.train_loader))
    for i, (inputs, target) in enumerate(self.train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        input_image, mask, m2s, m4s, m8s = inputs
        current_index = len(self.train_loader) * epoch + i
        if self.args.gpu:
            input_image = input_image.cuda()
            mask = mask.cuda()
            m2s = m2s.cuda()
            m4s = m4s.cuda()
            m8s = m8s.cuda()
            target = target.cuda()
        batch_size = input_image.size(0)
        # Real/fake label maps for the discriminator. They are allocated on the
        # data's device so the CPU path (args.gpu False) no longer needs CUDA.
        label_shape = (batch_size, self.patch, self.patch)
        valid = torch.ones(label_shape, device=input_image.device)
        fake = torch.zeros(label_shape, device=input_image.device)

        # ---------------
        #  Train model
        # ---------------
        self.optimizer_G.zero_grad()
        fake_input, mask8s, mask4s, mask2s = self.model(
            torch.cat((input_image, mask), 1))
        pred_fake = self.discriminator(fake_input, input_image)
        loss_GAN = self.criterion_GAN(pred_fake, valid)
        loss_pixel = self.criterion_L1(fake_input, target)
        # attention-mask losses (original note: either an MSE or NLL criterion)
        masked_loss8s = self.attentionLoss8s(mask8s, m8s)
        masked_loss4s = self.attentionLoss4s(mask4s, m4s)
        masked_loss2s = self.attentionLoss2s(mask2s, m2s)
        loss_G = (loss_GAN + 100 * loss_pixel + 90 * masked_loss8s
                  + 90 * masked_loss4s + 90 * masked_loss2s)
        loss_G.backward()
        self.optimizer_G.step()

        # ---------------------
        #  Train discriminator
        # ---------------------
        self.optimizer_D.zero_grad()
        pred_real = self.discriminator(target, input_image)
        loss_real = self.criterion_GAN(pred_real, valid)
        pred_fake = self.discriminator(fake_input.detach(), input_image)
        loss_fake = self.criterion_GAN(pred_fake, fake)
        loss_D = 0.5 * (loss_real + loss_fake)
        loss_D.backward()
        self.optimizer_D.step()

        # ---------------------
        #  Logger
        # ---------------------
        LoggerLossGGAN.update(loss_GAN.item(), batch_size)
        LoggerLossGL1.update(loss_pixel.item(), batch_size)
        LoggerLossG.update(loss_G.item(), batch_size)
        # The real/fake meters were previously fed each other's loss.
        LoggerLossDreal.update(loss_real.item(), batch_size)
        LoggerLossDfake.update(loss_fake.item(), batch_size)
        LoggerLossD.update(loss_D.item(), batch_size)
        lossMask8s.update(masked_loss8s.item(), batch_size)
        lossMask4s.update(masked_loss4s.item(), batch_size)
        lossMask2s.update(masked_loss2s.item(), batch_size)

        # ---------------------
        #  Visualize
        # ---------------------
        if i == 1:
            self.writer.add_images('train/Goutput', deNorm(fake_input), current_index)
            self.writer.add_images('train/target', deNorm(target), current_index)
            self.writer.add_images('train/input', deNorm(input_image), current_index)
            self.writer.add_images('train/mask', mask.repeat((1, 3, 1, 1)), current_index)
            self.writer.add_images('train/attention2s', mask2s.repeat(1, 3, 1, 1), current_index)
            self.writer.add_images('train/attention4s', mask4s.repeat(1, 3, 1, 1), current_index)
            self.writer.add_images('train/attention8s', mask8s.repeat(1, 3, 1, 1), current_index)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss D: {loss_d:.4f} | Loss G: {loss_g:.4f} | Loss L1: {loss_l1:.6f} '.format(
            batch=i + 1,
            size=len(self.train_loader),
            data=data_time.val,
            bt=batch_time.val,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss_d=LoggerLossD.avg,
            loss_g=LoggerLossGGAN.avg,
            loss_l1=LoggerLossGL1.avg)
        bar.next()
    bar.finish()
    self.writer.add_scalar('train/loss/GAN', LoggerLossGGAN.avg, epoch)
    self.writer.add_scalar('train/loss/D', LoggerLossD.avg, epoch)
    self.writer.add_scalar('train/loss/L1', LoggerLossGL1.avg, epoch)
    self.writer.add_scalar('train/loss/G', LoggerLossG.avg, epoch)
    self.writer.add_scalar('train/loss/Dreal', LoggerLossDreal.avg, epoch)
    self.writer.add_scalar('train/loss/Dfake', LoggerLossDfake.avg, epoch)
    self.writer.add_scalar('train/loss_Mask8s', lossMask8s.avg, epoch)
    self.writer.add_scalar('train/loss_Mask4s', lossMask4s.avg, epoch)
    self.writer.add_scalar('train/loss_Mask2s', lossMask2s.avg, epoch)
def validate(self, epoch):
    """Evaluate on ``self.val_loader``: PSNR, SSIM and attention-mask losses.

    Batches are ``(inputs, target)``, with ``inputs`` unpacking to
    ``(input_image, mask, m2s, m4s, m8s)``. PSNR is ``10 * log10(1 / mse)``,
    where ``mse`` comes from ``self.criterion_GAN`` on ``deNorm``-ed images.
    Averages are written to ``self.writer``, and the mean PSNR is stored in
    ``self.metric``.

    Args:
        epoch: step used for the TensorBoard scalars.
    """
    self.current_epoch = epoch
    batch_time = AverageMeter()
    data_time = AverageMeter()
    psnres = AverageMeter()
    ssimes = AverageMeter()
    lossMask8s = AverageMeter()
    lossMask4s = AverageMeter()
    lossMask2s = AverageMeter()
    # switch to evaluate mode
    self.model.eval()
    end = time.time()
    bar = Bar('Processing {} '.format(self.args.arch), max=len(self.val_loader))
    with torch.no_grad():
        for i, (inputs, target) in enumerate(self.val_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            input_image, mask, m2s, m4s, m8s = inputs
            if self.args.gpu:
                input_image = input_image.cuda()
                mask = mask.cuda()
                m2s = m2s.cuda()
                m4s = m4s.cuda()
                m8s = m8s.cuda()
                target = target.cuda()
            # attention maps at three scales (original note: 32,64,128)
            output, mask8s, mask4s, mask2s = self.model(
                torch.cat((input_image, mask), 1))
            output = deNorm(output)
            target = deNorm(target)
            masked_loss8s = self.attentionLoss8s(mask8s, m8s)
            masked_loss4s = self.attentionLoss4s(mask4s, m4s)
            masked_loss2s = self.attentionLoss2s(mask2s, m2s)
            ## psnr and ssim calculator.
            # NOTE(review): this PSNR formula assumes criterion_GAN is an MSE
            # loss and that deNorm maps images to [0, 1]; confirm.
            mse = self.criterion_GAN(output, target)
            psnr = 10 * log10(1 / mse.item())
            # .item() so the meter accumulates a Python float, not a device tensor
            ssim = pytorch_ssim.ssim(output, target).item()
            n = input_image.size(0)
            psnres.update(psnr, n)
            ssimes.update(ssim, n)
            lossMask8s.update(masked_loss8s.item(), n)
            lossMask4s.update(masked_loss4s.item(), n)
            lossMask2s.update(masked_loss2s.item(), n)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | SSIM: {ssim:.4f} | PSNR: {psnr:.4f}'.format(
                batch=i + 1,
                size=len(self.val_loader),
                data=data_time.val,
                bt=batch_time.val,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                ssim=ssimes.avg,
                psnr=psnres.avg)
            bar.next()
    bar.finish()
    self.writer.add_scalar('val/SSIM', ssimes.avg, epoch)
    self.writer.add_scalar('val/PSNR', psnres.avg, epoch)
    # Validation mask losses are logged under 'val/'. They previously reused
    # the 'train/' tags and collided with the training curves.
    self.writer.add_scalar('val/loss_Mask8s', lossMask8s.avg, epoch)
    self.writer.add_scalar('val/loss_Mask4s', lossMask4s.avg, epoch)
    self.writer.add_scalar('val/loss_Mask2s', lossMask2s.avg, epoch)
    self.metric = psnres.avg
def train(self, epoch):
    """Train ``self.model`` for one epoch on ``self.train_loader``.

    ``inputs`` carries the image in channels 0:3 and a mask in channel 3. The
    optimized loss is ``1e10 * L2 + 1e9 * gradient``; the gradient term is
    used only when ``self.args.gradient_loss`` is set. The scaled epoch
    averages are written to ``self.writer``.

    Args:
        epoch: epoch index. It is used as the scalar step and to derive the
            image-summary step.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    gradientes = AverageMeter()
    # switch to train mode
    self.model.train()
    end = time.time()
    bar = Bar('Processing {} '.format(self.args.arch), max=len(self.train_loader))
    for i, (inputs, target) in enumerate(self.train_loader):
        # measure data loading time (previously never recorded, so the bar
        # always showed "Data: 0.00s")
        data_time.update(time.time() - end)
        if self.args.gpu:
            inputs = inputs.cuda()
            target = target.cuda()
        image = inputs[:, 0:3, :, :]
        mask = inputs[:, 3:4, :, :]
        output = self.model(inputs)

        if i == 1:
            current_index = len(self.train_loader) * epoch + i
            self.writer.add_images('train/output', output, current_index)
            self.writer.add_images('train/target', target, current_index)
            self.writer.add_images('train/input', image, current_index)
            self.writer.add_images('train/mask', mask.repeat(1, 3, 1, 1), current_index)

        L2_loss = self.loss(output, target)
        if self.args.gradient_loss:
            # NOTE(review): the reference gradients come from the input image,
            # not from ``target`` as the tgx/tgy names suggest; confirm this is
            # intended.
            tgx, tgy = image_gradient(image)
            ogx, ogy = image_gradient(output)
            gradient_loss = self.gradient_loss_y(ogy, tgy) + self.gradient_loss_x(ogx, tgx)
        else:
            gradient_loss = 0
        total_loss = 1e10 * L2_loss + 1e9 * gradient_loss

        # compute gradient and do SGD step
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        # measure accuracy and record loss
        losses.update(1e10 * L2_loss.item(), inputs.size(0))
        if self.args.gradient_loss:
            gradientes.update(1e9 * gradient_loss.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f}'.format(
            batch=i + 1,
            size=len(self.train_loader),
            data=data_time.val,
            bt=batch_time.val,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss_label=losses.avg)
        bar.next()
    bar.finish()
    self.writer.add_scalar('train/loss_L2', losses.avg, epoch)
    self.writer.add_scalar('train/loss_gradient', gradientes.avg, epoch)
def validate(self, epoch):
    """Evaluate ``self.model`` on ``self.val_loader`` and log L2/PSNR/SSIM.

    The reported L2 loss is ``1e10 * self.loss(output, target)``. PSNR is
    ``10 * log10(1 / mse)`` on the unscaled loss. Averages are written to
    ``self.writer``, and the mean PSNR is stored in ``self.metric``.

    Args:
        epoch: step used for the TensorBoard scalars.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    ssimes = AverageMeter()
    psnres = AverageMeter()
    # switch to evaluate mode
    self.model.eval()
    end = time.time()
    bar = Bar('Processing', max=len(self.val_loader))
    with torch.no_grad():
        for i, (inputs, target) in enumerate(self.val_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            if self.args.gpu:
                inputs = inputs.cuda()
                target = target.cuda()
            output = self.model(inputs)
            # NOTE(review): the PSNR formula assumes self.loss is an MSE loss on
            # images in [0, 1]; confirm.
            mse = self.loss(output, target)
            L2_loss = 1e10 * mse
            psnr = 10 * log10(1 / mse.item())
            # .item() so the meter accumulates a Python float, not a device tensor
            ssim = pytorch_ssim.ssim(output, target).item()
            losses.update(L2_loss.item(), inputs.size(0))
            psnres.update(psnr, inputs.size(0))
            ssimes.update(ssim, inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_L2: {loss_label:.4f} | SSIM: {ssim:.4f} | PSNR: {psnr:.4f}'.format(
                batch=i + 1,
                size=len(self.val_loader),
                data=data_time.val,
                bt=batch_time.val,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss_label=losses.avg,
                psnr=psnres.avg,
                ssim=ssimes.avg,
            )
            bar.next()
    bar.finish()
    self.writer.add_scalar('val/loss_L2', losses.avg, epoch)
    self.writer.add_scalar('val/PSNR', psnres.avg, epoch)
    self.writer.add_scalar('val/SSIM', ssimes.avg, epoch)
    self.metric = psnres.avg
def train(self, epoch):
    """Train ``self.model`` for ``self.hl`` passes over ``self.train_loader``.

    Each batch is a dict with ``'image'``, ``'target'`` and ``'mask'``; the
    model receives ``cat([image, mask], dim=1)``. When ``self.args.res`` is set,
    the input image is added to the model output. The loss is ``self.loss``
    plus ``self.args.style_loss`` times ``self.vggloss``; the VGG term is
    computed only when ``style_loss > 0``. A progress line is printed every
    1000 global steps. When ``self.args.freq > 0``, the model is validated,
    flushed and checkpointed every ``freq`` steps.

    Args:
        epoch: epoch index used to derive the global step.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    lossvgg = AverageMeter()

    # switch to train mode
    self.model.train()
    end = time.time()
    num_batches = len(self.train_loader)
    bar = Bar('Processing', max=num_batches * self.hl)
    # Defined up front so the final record() call works even for an empty loader.
    current_index = epoch * self.hl * num_batches
    for rep in range(self.hl):
        for i, batches in enumerate(self.train_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            inputs = batches['image'].to(self.device)
            target = batches['target'].to(self.device)
            mask = batches['mask'].to(self.device)
            # Global step. The pass index ``rep`` is included so repeated passes
            # (self.hl > 1) do not reuse step numbers for logging, validation
            # and checkpoints. For hl == 1 this equals
            # len(train_loader) * epoch + i, as before.
            current_index = (epoch * self.hl + rep) * num_batches + i

            feeded = torch.cat([inputs, mask], dim=1)
            output = self.model(feeded)
            if self.args.res:
                output = output + inputs

            L2_loss = self.loss(output, target)
            if self.args.style_loss > 0:
                vgg_loss = self.vggloss(output, target, mask)
            else:
                vgg_loss = 0
            total_loss = L2_loss + self.args.style_loss * vgg_loss

            # compute gradient and do SGD step
            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            # measure accuracy and record loss
            losses.update(L2_loss.item(), inputs.size(0))
            if self.args.style_loss > 0:
                lossvgg.update(vgg_loss.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress (the string is only built when it is printed)
            if current_index % 1000 == 0:
                suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format(
                    batch=i + 1,
                    size=num_batches,
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss_label=losses.avg,
                    loss_vgg=lossvgg.avg)
                print(suffix)

            if self.args.freq > 0 and current_index % self.args.freq == 0:
                self.validate(current_index)
                self.flush()
                self.save_checkpoint()
                # validate() may switch the model to eval mode. Restore training
                # mode before the next optimisation step.
                self.model.train()

    self.record('train/loss_L2', losses.avg, current_index)
def validate(self, epoch):
    """Score ``self.model`` on ``self.val_loader`` and record loss, PSNR and SSIM.

    Each validation batch is a dict holding ``'image'``, ``'target'`` and
    ``'mask'``. The network sees the image and mask concatenated on dim 1.
    With ``self.args.res`` the input image is added back to the output.
    PSNR is derived from the ``self.loss`` value as ``10 * log10(1 / loss)``.
    The sample-weighted averages are printed, passed to ``self.record`` under
    ``val/...`` tags, and the mean PSNR is kept in ``self.metric``.

    Args:
        epoch: step value used when printing and recording.
    """
    batch_time, data_time = AverageMeter(), AverageMeter()
    losses, ssimes, psnres = AverageMeter(), AverageMeter(), AverageMeter()

    self.model.eval()  # evaluation mode for the whole pass
    tick = time.time()
    with torch.no_grad():
        for batch in self.val_loader:
            image = batch['image'].to(self.device)
            gt = batch['target'].to(self.device)
            hint = batch['mask'].to(self.device)

            network_in = torch.cat([image, hint], dim=1).to(self.device)
            prediction = self.model(network_in)
            if self.args.res:
                # residual mode: the final prediction is model output + input
                prediction = prediction + image

            l2 = self.loss(prediction, gt)
            count = image.size(0)
            l2_value = l2.item()
            losses.update(l2_value, count)
            psnres.update(10 * log10(1 / l2_value), count)
            ssimes.update(pytorch_ssim.ssim(prediction, gt).item(), count)

            now = time.time()
            batch_time.update(now - tick)
            tick = now

    print("Epoches:%s,Losses:%.3f,PSNR:%.3f,SSIM:%.3f" % (epoch + 1, losses.avg, psnres.avg, ssimes.avg))
    for tag, meter in (('val/loss_L2', losses), ('val/PSNR', psnres), ('val/SSIM', ssimes)):
        self.record(tag, meter.avg, epoch)
    self.metric = psnres.avg
def train(self, epoch):
    """Run one adversarial (generator + discriminator) training epoch.

    Batches are ``(inputs, target)``, where channels 0:3 of ``inputs`` are the
    image and channel 3 is the mask. The generator loss is
    ``GAN + 100 * L1``. The discriminator loss is ``0.5 * (real + fake)``,
    computed on the detached generator output. Epoch-averaged losses are
    written to ``self.writer`` at step ``epoch``.

    Args:
        epoch: epoch index. It is combined with the batch index to form the
            image-summary step.
    """
    self.current_epoch = epoch
    batch_time = AverageMeter()
    data_time = AverageMeter()
    LoggerLossG = AverageMeter()
    LoggerLossGGAN = AverageMeter()
    LoggerLossGL1 = AverageMeter()
    LoggerLossD = AverageMeter()
    LoggerLossDreal = AverageMeter()
    LoggerLossDfake = AverageMeter()

    # switch to train mode
    self.model.train()
    self.discriminator.train()
    end = time.time()
    num_batches = len(self.train_loader)
    bar = Bar('Processing {} '.format(self.args.arch), max=num_batches)
    # Log images about ten times per epoch. max(1, ...) avoids a modulo by
    # zero when the loader has fewer than 10 batches.
    vis_every = max(1, num_batches // 10)
    for i, (inputs, target) in enumerate(self.train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        current_index = num_batches * epoch + i
        if self.args.gpu:
            inputs = inputs.cuda()
            target = target.cuda()
        input_image = inputs[:, 0:3, :, :]
        mask = inputs[:, 3:4, :, :]
        batch_size = inputs.size(0)
        # Real/fake label maps for the discriminator. They are allocated on the
        # data's device so the CPU path (args.gpu False) no longer needs CUDA.
        label_shape = (batch_size, self.patch, self.patch)
        valid = torch.ones(label_shape, device=inputs.device)
        fake = torch.zeros(label_shape, device=inputs.device)

        # ---------------
        #  Train model
        # ---------------
        self.optimizer_G.zero_grad()
        fake_input = self.model(inputs)
        pred_fake = self.discriminator(fake_input, input_image)
        loss_GAN = self.criterion_GAN(pred_fake, valid)
        loss_pixel = self.criterion_L1(fake_input, target)
        loss_G = loss_GAN + 100 * loss_pixel
        loss_G.backward()
        self.optimizer_G.step()

        # ---------------------
        #  Train discriminator
        # ---------------------
        self.optimizer_D.zero_grad()
        pred_real = self.discriminator(target, input_image)
        loss_real = self.criterion_GAN(pred_real, valid)
        pred_fake = self.discriminator(fake_input.detach(), input_image)
        loss_fake = self.criterion_GAN(pred_fake, fake)
        loss_D = 0.5 * (loss_real + loss_fake)
        loss_D.backward()
        self.optimizer_D.step()

        # ---------------------
        #  Logger
        # ---------------------
        LoggerLossGGAN.update(loss_GAN.item(), batch_size)
        LoggerLossGL1.update(loss_pixel.item(), batch_size)
        LoggerLossG.update(loss_G.item(), batch_size)
        # The real/fake meters were previously fed each other's loss.
        LoggerLossDreal.update(loss_real.item(), batch_size)
        LoggerLossDfake.update(loss_fake.item(), batch_size)
        LoggerLossD.update(loss_D.item(), batch_size)

        # ---------------------
        #  Visualize
        # ---------------------
        if current_index % vis_every == 0:
            self.writer.add_images('train/Goutput', deNorm(fake_input), current_index)
            self.writer.add_images('train/target', deNorm(target), current_index)
            self.writer.add_images('train/input', deNorm(input_image), current_index)
            self.writer.add_images('train/mask', mask.repeat((1, 3, 1, 1)), current_index)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss D: {loss_d:.4f} | Loss G: {loss_g:.4f} | Loss L1: {loss_l1:.6f} '.format(
            batch=i + 1,
            size=num_batches,
            data=data_time.val,
            bt=batch_time.val,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss_d=LoggerLossD.avg,
            loss_g=LoggerLossGGAN.avg,
            loss_l1=LoggerLossGL1.avg)
        bar.next()
    bar.finish()
    self.writer.add_scalar('train/loss/GAN', LoggerLossGGAN.avg, epoch)
    self.writer.add_scalar('train/loss/D', LoggerLossD.avg, epoch)
    self.writer.add_scalar('train/loss/L1', LoggerLossGL1.avg, epoch)
    self.writer.add_scalar('train/loss/G', LoggerLossG.avg, epoch)
    self.writer.add_scalar('train/loss/Dreal', LoggerLossDreal.avg, epoch)
    self.writer.add_scalar('train/loss/Dfake', LoggerLossDfake.avg, epoch)
def validate(self, epoch):
    """Evaluate on ``self.val_loader`` and log PSNR, SSIM and sample images.

    Batches are ``(inputs, target)``. Channels 0:3 of ``inputs`` are the image
    and channel 3 is the mask. Outputs and targets are ``deNorm``-ed before
    scoring. PSNR is ``10 * log10(1 / mse)``, with ``mse`` taken from
    ``self.criterion_GAN``. Images of batch index 10 are sent to
    ``self.writer``. The mean PSNR is stored in ``self.metric``.

    Args:
        epoch: step used for TensorBoard scalars and images.
    """
    self.current_epoch = epoch
    batch_time = AverageMeter()
    data_time = AverageMeter()
    psnres = AverageMeter()
    ssimes = AverageMeter()
    # switch to evaluate mode
    self.model.eval()
    end = time.time()
    bar = Bar('Processing {} '.format(self.args.arch), max=len(self.val_loader))
    with torch.no_grad():
        for i, (inputs, target) in enumerate(self.val_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            if self.args.gpu:
                inputs = inputs.cuda()
                target = target.cuda()
            output = deNorm(self.model(inputs))
            target = deNorm(target)
            ## psnr and ssim calculator.
            # NOTE(review): this PSNR formula assumes criterion_GAN is an MSE
            # loss and that deNorm maps images to [0, 1]; confirm.
            mse = self.criterion_GAN(output, target)
            psnr = 10 * log10(1 / mse.item())
            # .item() so the meter accumulates a Python float, not a device tensor
            ssim = pytorch_ssim.ssim(output, target).item()
            if i == 10:
                self.writer.add_images('val/Goutput', output, epoch)
                self.writer.add_images('val/target', target, epoch)
                self.writer.add_images('val/input', deNorm(inputs[:, 0:3, :, :]), epoch)
                self.writer.add_images(
                    'val/mask', inputs[:, 3:4, :, :].repeat((1, 3, 1, 1)), epoch)
            psnres.update(psnr, inputs.size(0))
            ssimes.update(ssim, inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | SSIM: {ssim:.4f} | PSNR: {psnr:.4f}'.format(
                batch=i + 1,
                size=len(self.val_loader),
                data=data_time.val,
                bt=batch_time.val,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                ssim=ssimes.avg,
                psnr=psnres.avg)
            bar.next()
    bar.finish()
    self.writer.add_scalar('val/SSIM', ssimes.avg, epoch)
    self.writer.add_scalar('val/PSNR', psnres.avg, epoch)
    self.metric = psnres.avg
def validate(val_loader, model, criterions, args):
    """Evaluate a joint classification + segmentation model on ``val_loader``.

    Args:
        val_loader: yields ``(inputs, target, label, full_image)`` batches.
        model: called as ``model(inputs, full_image)``; returns
            ``(output_label, output_mask)``.
        criterions: sequence whose first two entries are the classification
            and segmentation criteria.
        args: object with a ``gpu`` flag; when truthy, tensors are moved with
            ``.cuda()``.

    Returns:
        Tuple ``(label_loss, label_acc, mask_loss, mask_acc)``, each a
        sample-weighted average over the loader.
    """
    criterion_classification = criterions[0]
    criterion_segmentation = criterions[1]
    losses_label = AverageMeter()
    acces_label = AverageMeter()
    losses_mask = AverageMeter()
    acces_mask = AverageMeter()

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        for inputs, target, label, full_image in val_loader:
            if args.gpu:
                inputs = inputs.cuda()
                full_image = full_image.cuda()
                target = target.cuda()
                label = label.cuda()
            # torch.autograd.Variable is a deprecated no-op wrapper, so plain
            # tensors are passed directly.
            output_label, output_mask = model(inputs, full_image)
            loss_label = criterion_classification(output_label, label.long())
            loss_mask = criterion_segmentation(output_mask, target.long())
            acc_label = accuracy(output_label.detach().cpu(), label.cpu())
            acc_mask = accuracy(output_mask.detach().cpu(), target.cpu())

            # record loss and accuracy
            batch_size = inputs.size(0)
            losses_label.update(loss_label.item(), batch_size)
            losses_mask.update(loss_mask.item(), batch_size)
            acces_label.update(acc_label, batch_size)
            acces_mask.update(acc_mask, batch_size)

    return losses_label.avg, acces_label.avg, losses_mask.avg, acces_mask.avg
def train(train_loader, model, criterions, optimizer,args):
    """Train a joint classification + segmentation model for one epoch.

    The optimized loss is ``label_loss + 10 * mask_loss``.

    Args:
        train_loader: yields ``(inputs, target, label, full_image)`` batches.
        model: called as ``model(inputs, full_image)``; returns
            ``(output_label, output_mask)``.
        criterions: sequence whose first two entries are the classification
            and segmentation criteria.
        optimizer: optimizer stepped once per batch.
        args: object with a ``gpu`` flag; when truthy, tensors are moved with
            ``.cuda()``.

    Returns:
        Tuple ``(label_loss, label_acc, mask_loss, mask_acc)``, each a
        sample-weighted average over the epoch.
    """
    losses_label = AverageMeter()
    acces_label = AverageMeter()
    losses_mask = AverageMeter()
    acces_mask = AverageMeter()
    criterion_classification = criterions[0]
    criterion_segmentation = criterions[1]

    # switch to train mode
    model.train()
    for inputs, target, label, full_image in train_loader:
        if args.gpu:
            inputs = inputs.cuda()
            full_image = full_image.cuda()
            target = target.cuda()
            label = label.cuda()
        # torch.autograd.Variable is a deprecated no-op wrapper, so plain
        # tensors are passed directly.
        output_label, output_mask = model(inputs, full_image)
        loss_label = criterion_classification(output_label, label.long())
        loss_mask = criterion_segmentation(output_mask, target.long())
        # the mask loss is weighted 10x relative to the classification loss
        loss = loss_label + 10 * loss_mask
        acc_label = accuracy(output_label.detach(), label)
        acc_mask = accuracy(output_mask.detach(), target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # record loss and accuracy
        batch_size = inputs.size(0)
        losses_label.update(loss_label.item(), batch_size)
        losses_mask.update(loss_mask.item(), batch_size)
        acces_label.update(acc_label, batch_size)
        acces_mask.update(acc_mask, batch_size)

    return losses_label.avg, acces_label.avg, losses_mask.avg, acces_mask.avg