def test(args, model, device, test_loader):
    """Evaluate a segmentation `model` on `test_loader`.

    Accumulates per-pixel predictions into a StreamSegMetrics tracker and
    logs visualizations of the first batch (input / target / prediction)
    to the global `vp` visualizer.

    Returns:
        dict: StreamSegMetrics results; includes 'Overall Acc' and
        'Mean IoU' (other keys depend on StreamSegMetrics — confirm there).
    """
    model.eval()
    # NOTE(review): removed dead locals `test_loss = 0` and `correct = 0`
    # from the original — they were never read or updated.
    seg_metrics = StreamSegMetrics(args.num_classes)
    with torch.no_grad():
        for i, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device, dtype=torch.long)
            output = model(data)
            # argmax over the class channel -> label map for metric update
            seg_metrics.update(
                output.max(1)[1].detach().cpu().numpy().astype('uint8'),
                target.detach().cpu().numpy().astype('uint8'))
            if i == 0:
                # Visualize only the first batch. Inputs appear to be
                # normalized to [-1, 1] (hence (x + 1) / 2) — TODO confirm
                # against the dataloader's transform.
                vp.add_image(
                    'input',
                    pack_images(((data + 1) / 2).clamp(0, 1.0).cpu().numpy()))
                vp.add_image(
                    'target',
                    pack_images(
                        test_loader.dataset.decode_target(target.cpu().numpy()),
                        channel_last=True).astype('uint8'))
                vp.add_image(
                    'pred',
                    pack_images(
                        test_loader.dataset.decode_target(
                            output.max(1)[1].detach().cpu().numpy().astype('uint8')),
                        channel_last=True).astype('uint8'))
    results = seg_metrics.get_results()
    print('\nTest set: Acc= %.6f, mIoU: %.6f\n' %
          (results['Overall Acc'], results['Mean IoU']))
    return results
def warm_up(self, summary, epochs=50):
    """Warm-up phase: train only the generator against the frozen teacher.

    Each iteration samples latents, generates images, and minimizes
    ``alpha * (l_ie + l_oh + l_bn)`` where:
      * l_ie  — entropy term on the batch-mean teacher softmax
                (diversity across the batch),
      * l_oh  — mean per-sample entropy of the teacher softmax
                (confidence per sample),
      * l_bn  — BN-statistics loss accumulated by the hooked layers in
                ``self.loss_r_feature_layers``.

    Side effects: saves per-class sample grids for the last 3 epochs under
    ``<saved_img_path>/warm_up/``, logs images/scalars to `summary`, and
    writes the generator weights to ``<saved_model_path>/warm_up_gan.pt``.
    """
    print('-' * 30 + 'Warm up start' + '-' * 30)
    self.generator.train()
    for epoch in range(epochs):
        for it in range(self.opt.iter):
            latent = torch.randn(self.opt.batch_size, self.opt.latent_dim).cuda()
            self.optimizer_G.zero_grad()
            gen_imgs = self.generator(latent)
            o_T = self.teacher(gen_imgs)
            pred = o_T.data.max(1)[1]
            so_T = torch.nn.functional.softmax(o_T, dim=1)
            so_T_mean = so_T.mean(dim=0)
            l_ie = (so_T_mean * torch.log(so_T_mean)).sum()       # IE loss
            l_oh = -(so_T * torch.log(so_T)).sum(dim=1).mean()    # one-hot entropy
            l_bn = 0                                              # BN loss
            for hooked in self.loss_r_feature_layers:
                l_bn = l_bn + hooked.G_kd_loss.sum()
            l_s = self.opt.alpha * (l_ie + l_oh + l_bn)
            l_s.backward()
            self.optimizer_G.step()
            if it == 1:
                print("[Epoch %d/%d] [loss_oh: %f] [loss_ie: %f] [loss_BN: %f] "
                      % (epoch, epochs, l_oh.item(), l_ie.item(), l_bn.item()))
        self.scheduler_G.step()
        saved_img_path = os.path.join(self.opt.saved_img_path + 'warm_up/')
        # Keep sample dumps only for the final 3 epochs, bucketed by the
        # teacher's predicted class for the last generated batch.
        if epoch >= epochs - 3:
            for m in range(np.shape(gen_imgs)[0]):
                save_dir = saved_img_path + str(epoch) + '/' + str(int(pred[m])) + '/'
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                vutils.save_image(gen_imgs[m, :, :, :].data.clone(),
                                  save_dir + str(m) + '.png',
                                  normalize=True)
        # CIFAR-style mean/std here — presumably matches the teacher's
        # training normalization; verify against the data pipeline.
        summary.add_image('warmup/generated',
                          pack_images(
                              denormalize(gen_imgs.data,
                                          (0.4914, 0.4822, 0.4465),
                                          (0.2023, 0.1994, 0.2010))
                              .clamp(0, 1).detach().cpu().numpy()),
                          global_step=epoch)
        summary.add_scalar('warmup_loss_sum', l_s.item(), epoch)
    if not os.path.exists(self.opt.saved_model_path):
        os.makedirs(self.opt.saved_model_path)
    torch.save(self.generator.state_dict(),
               self.opt.saved_model_path + 'warm_up_gan.pt')
    print('-' * 30 + 'Warm up end' + '-' * 30)
def test(args, student, generator, device, test_loader, epoch=0):
    """Evaluate the student classifier on `test_loader`.

    Also samples the generator once per batch and, for the first batch,
    logs denormalized real and generated images to the global `vp`
    visualizer. `epoch` is accepted for interface compatibility but unused.

    Returns:
        float: top-1 accuracy over the whole test set.
    """
    student.eval()
    generator.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            noise = torch.randn((data.shape[0], args.nz, 1, 1),
                                device=data.device, dtype=data.dtype)
            fake = generator(noise)
            output = student(data)
            if batch_idx == 0:
                # CIFAR-style mean/std — presumably matches the training
                # normalization; verify against the data pipeline.
                stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                real_imgs = denormalize(data, *stats).clamp(0, 1)
                fake_imgs = denormalize(fake, *stats).clamp(0, 1)
                vp.add_image('input',
                             pack_images(real_imgs.detach().cpu().numpy()))
                vp.add_image('generated',
                             pack_images(fake_imgs.detach().cpu().numpy()))
            # Sum (not mean) so the final division by dataset size is exact.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    n_samples = len(test_loader.dataset)
    test_loss /= n_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.format(
        test_loss, correct, n_samples, 100. * correct / n_samples))
    acc = correct / n_samples
    return acc
def test(args, student, teacher, generator, device, test_loader):
    """Evaluate the student segmentation network (11 classes, CamVid-style).

    Optionally (``args.save_img``) writes per-image PNGs of input, target,
    student prediction, and teacher prediction to ``results/DFAD-camvid``.
    For the first batch, logs input/generated/target/pred/teacher images
    (plus teacher & student predictions on generated images) to the global
    `vp` visualizer.

    Returns:
        dict: StreamSegMetrics results; includes 'Overall Acc' and 'Mean IoU'.
    """
    student.eval()
    generator.eval()
    teacher.eval()
    seg_metrics = StreamSegMetrics(11)
    if args.save_img:
        os.makedirs('results/DFAD-camvid', exist_ok=True)
    img_idx = 0
    with torch.no_grad():
        for i, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            z = torch.randn((data.shape[0], args.nz, 1, 1),
                            device=data.device, dtype=data.dtype)
            fake = generator(z)
            output = student(data)
            # FIX: the original ran `teacher(data)` twice when both
            # args.save_img and i == 0 held; compute it lazily, at most once.
            t_out = None
            if args.save_img:
                t_out = teacher(data)
                # Inputs appear to be in [-1, 1] (hence (x+1)/2*255) —
                # TODO confirm against the dataloader's transform.
                input_imgs = (((data + 1) / 2) * 255).clamp(0, 255) \
                    .detach().cpu().numpy().transpose(0, 2, 3, 1).astype('uint8')
                colored_preds = test_loader.dataset.decode_target(
                    output.max(1)[1].detach().cpu().numpy()).astype('uint8')
                colored_teacher_preds = test_loader.dataset.decode_target(
                    t_out.max(1)[1].detach().cpu().numpy()).astype('uint8')
                colored_targets = test_loader.dataset.decode_target(
                    target.detach().cpu().numpy()).astype('uint8')
                for _pred, _img, _target, _tpred in zip(
                        colored_preds, input_imgs,
                        colored_targets, colored_teacher_preds):
                    Image.fromarray(_pred).save(
                        'results/DFAD-camvid/%d_pred.png' % img_idx)
                    Image.fromarray(_img).save(
                        'results/DFAD-camvid/%d_img.png' % img_idx)
                    Image.fromarray(_target).save(
                        'results/DFAD-camvid/%d_target.png' % img_idx)
                    Image.fromarray(_tpred).save(
                        'results/DFAD-camvid/%d_teacher.png' % img_idx)
                    img_idx += 1
            if i == 0:
                if t_out is None:
                    t_out = teacher(data)
                t_out_onfake = teacher(fake)
                s_out_onfake = student(fake)
                vp.add_image(
                    'input',
                    pack_images(((data + 1) / 2).clamp(0, 1).detach().cpu().numpy()))
                vp.add_image(
                    'generated',
                    pack_images(((fake + 1) / 2).clamp(0, 1).detach().cpu().numpy()))
                vp.add_image(
                    'target',
                    pack_images(
                        test_loader.dataset.decode_target(target.cpu().numpy()),
                        channel_last=True).astype('uint8'))
                vp.add_image(
                    'pred',
                    pack_images(
                        test_loader.dataset.decode_target(
                            output.max(1)[1].detach().cpu().numpy().astype('uint8')),
                        channel_last=True).astype('uint8'))
                vp.add_image(
                    'teacher',
                    pack_images(
                        test_loader.dataset.decode_target(
                            t_out.max(1)[1].detach().cpu().numpy().astype('uint8')),
                        channel_last=True).astype('uint8'))
                vp.add_image(
                    'teacher-onfake',
                    pack_images(
                        test_loader.dataset.decode_target(
                            t_out_onfake.max(1)[1].detach().cpu().numpy().astype('uint8')),
                        channel_last=True).astype('uint8'))
                vp.add_image(
                    'student-onfake',
                    pack_images(
                        test_loader.dataset.decode_target(
                            s_out_onfake.max(1)[1].detach().cpu().numpy().astype('uint8')),
                        channel_last=True).astype('uint8'))
            seg_metrics.update(
                output.max(1)[1].detach().cpu().numpy().astype('uint8'),
                target.detach().cpu().numpy().astype('uint8'))
    results = seg_metrics.get_results()
    print('\nTest set: Acc= %.6f, mIoU: %.6f\n' %
          (results['Overall Acc'], results['Mean IoU']))
    return results
def build(self, summary):
    """Main distillation loop: alternate generator and student updates.

    Per iteration: one generator step (maximize student/teacher KD
    disagreement while keeping teacher-entropy and BN losses low), then
    ten student steps (minimize KD loss against the detached teacher).
    Logs to `summary` per epoch, runs ``self.test`` per epoch, saves
    per-class sample grids for the last 3 epochs, and writes final
    student/generator weights to ``self.opt.saved_model_path``.

    Requires generator warm-up: either runs it (``opt.do_warmup``) or
    loads ``warm_up_gan.pt``.
    """
    print('-' * 30 + 'Main start' + '-' * 30)
    self.accr_best = 0
    self.accr = 0
    if self.opt.do_warmup:  # idiom fix: was `== True`
        self.warm_up(summary)
    else:
        checkpoint = torch.load(self.opt.saved_model_path + 'warm_up_gan.pt')
        self.generator.load_state_dict(checkpoint)
        if torch.cuda.is_available():
            self.generator = self.generator.cuda()
    for epoch in range(self.opt.n_epochs):
        for i in range(self.opt.iter):
            # --- generator step (1x per iteration) ---
            for _ in range(1):
                self.student.eval()
                self.generator.train()
                z = torch.randn(self.opt.batch_size, self.opt.latent_dim).cuda()
                self.optimizer_G.zero_grad()
                gen_imgs = self.generator(z)
                o_T = self.teacher(gen_imgs)
                o_S = self.student(gen_imgs)
                pred = o_T.data.max(1)[1]
                so_T = torch.nn.functional.softmax(o_T, dim=1)
                so_T_mean = so_T.mean(dim=0)
                l_ie = (so_T_mean * torch.log(so_T_mean)).sum()     # IE loss
                l_oh = -(so_T * torch.log(so_T)).sum(dim=1).mean()  # one-hot entropy
                l_bn = 0                                            # BN loss
                for mod in self.loss_r_feature_layers:
                    l_bn += mod.G_kd_loss.sum()
                l_s = l_ie + l_oh + l_bn
                l_kd_for_G = kd_loss(o_S, o_T)  # KD loss (adversarial for G)
                g_loss = -l_kd_for_G + self.opt.alpha * l_s
                g_loss.backward()
                self.optimizer_G.step()
            # --- student steps (10x per iteration) ---
            for _ in range(10):
                self.student.train()
                self.generator.eval()
                self.optimizer_S.zero_grad()
                z = torch.randn(self.opt.batch_size, self.opt.latent_dim).cuda()
                gen_imgs = self.generator(z)
                o_T = self.teacher(gen_imgs)
                o_S = self.student(gen_imgs)
                l_kd_for_S = kd_loss(o_S, o_T.detach())  # KD loss
                s_loss = l_kd_for_S
                s_loss.backward()
                self.optimizer_S.step()
            if epoch % 10 == 0 and i == 0:
                # BUG FIX: the original printed `l_l.item()` but `l_l` was
                # never defined (NameError every 10th epoch). The printed
                # label is "loss_logit"; the generator-side KD (logit) loss
                # l_kd_for_G is the matching computed quantity —
                # NOTE(review): confirm this was the intended value.
                print("[Epoch %d/%d] [loss_logit: %f] [loss_oh: %f] [loss_ie: %f] [loss_BN: %f] [loss_kd: %f]"
                      % (epoch, self.opt.n_epochs, l_kd_for_G.item(),
                         l_oh.item(), l_ie.item(), l_bn.item(),
                         l_kd_for_S.item()))
            if epoch % 10 != 0 and i == 0:
                print("[Epoch %d/%d] [loss_kd: %f]"
                      % (epoch, self.opt.n_epochs, l_kd_for_S.item()))
        summary.add_image('main/generated',
                          pack_images(
                              denormalize(gen_imgs.data,
                                          (0.4914, 0.4822, 0.4465),
                                          (0.2023, 0.1994, 0.2010))
                              .clamp(0, 1).detach().cpu().numpy()),
                          global_step=epoch)
        summary.add_scalar('main/student_loss', l_kd_for_S.item(), epoch)
        summary.add_scalar('main/generator_loss', g_loss.item(), epoch)
        self.scheduler_S.step()
        self.scheduler_G.step()
        # save generated image per epoch
        self.test(summary, epoch)
        saved_img_path = os.path.join(self.opt.saved_img_path + 'main/')
        if epoch >= self.opt.n_epochs - 3:
            # NOTE(review): `pred` comes from the last generator step while
            # `gen_imgs` is the last student-step batch — the class buckets
            # may not match the saved images; confirm intended pairing.
            for m in range(np.shape(gen_imgs)[0]):
                save_dir = saved_img_path + str(epoch) + '/' + str(int(pred[m])) + '/'
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                vutils.save_image(gen_imgs[m, :, :, :].data.clone(),
                                  save_dir + str(m) + '.png',
                                  normalize=True)
    torch.save(self.student.state_dict(),
               self.opt.saved_model_path + 'student.pt')
    torch.save(self.generator.state_dict(),
               self.opt.saved_model_path + 'gan.pt')
    summary.close()
    print('-' * 30 + 'Main end' + '-' * 30)