Пример #1
0
def test(args, model, device, test_loader):
    """Evaluate a segmentation model and report pixel accuracy / mean IoU.

    Runs ``model`` over every batch of ``test_loader`` without gradients,
    feeds the arg-max predictions into a ``StreamSegMetrics`` sized by
    ``args.num_classes``, and pushes the first batch's input, target and
    prediction images to the ``vp`` visualizer.

    Args:
        args: namespace providing ``num_classes``.
        model: segmentation network; predictions are taken with
            ``output.max(1)``, so the class axis is presumably dim 1
            (TODO confirm against the model).
        device: device each batch is moved to.
        test_loader: loader whose ``dataset`` exposes ``decode_target``.

    Returns:
        The dict returned by ``seg_metrics.get_results()``; it contains at
        least the ``'Overall Acc'`` and ``'Mean IoU'`` keys printed below.
    """
    model.eval()
    seg_metrics = StreamSegMetrics(args.num_classes)

    with torch.no_grad():
        for i, (data, target) in enumerate(test_loader):
            data = data.to(device)
            target = target.to(device, dtype=torch.long)
            output = model(data)

            # Compute the arg-max prediction and the host copy of the target
            # once; both the metrics and the visualization reuse them.
            pred = output.max(1)[1].detach().cpu().numpy().astype('uint8')
            target_np = target.detach().cpu().numpy()
            seg_metrics.update(pred, target_np.astype('uint8'))

            if i == 0:
                # (x + 1) / 2 maps [-1, 1] to [0, 1] for display; assumes the
                # loader normalizes inputs to [-1, 1] (TODO confirm).
                vp.add_image(
                    'input',
                    pack_images(((data + 1) / 2).clamp(0, 1.0).cpu().numpy()))
                vp.add_image(
                    'target',
                    pack_images(test_loader.dataset.decode_target(target_np),
                                channel_last=True).astype('uint8'))
                vp.add_image(
                    'pred',
                    pack_images(test_loader.dataset.decode_target(pred),
                                channel_last=True).astype('uint8'))

    results = seg_metrics.get_results()

    print('\nTest set: Acc= %.6f, mIoU: %.6f\n' %
          (results['Overall Acc'], results['Mean IoU']))
    return results
Пример #2
0
    def warm_up(self, summary, epochs=50):
        """Pre-train the generator against the teacher before distillation.

        Each epoch runs ``self.opt.iter`` generator updates. Every update
        samples latent codes, generates images, queries the teacher and
        minimizes ``self.opt.alpha * (l_ie + l_oh + l_bn)`` where:

        * ``l_ie`` = sum(p_mean * log p_mean) over the batch-mean teacher
          softmax, i.e. the negative entropy of the mean prediction, so
          minimizing it favours class-balanced batches;
        * ``l_oh`` = mean per-sample entropy of the teacher softmax, so
          minimizing it favours confident (near one-hot) predictions;
        * ``l_bn`` = sum of ``mod.G_kd_loss`` over
          ``self.loss_r_feature_layers``.

        During the last three epochs, the final batch of each epoch is saved
        under ``<saved_img_path>warm_up/<epoch>/<teacher class>/``. The
        generator weights are written to
        ``<saved_model_path>warm_up_gan.pt`` at the end.

        Args:
            summary: logger exposing ``add_image`` / ``add_scalar``.
            epochs: number of warm-up epochs.
        """
        print('-' * 30 + 'Warm up start' + '-' * 30)
        self.generator.train()

        for epoch in range(epochs):
            for i in range(self.opt.iter):
                z = torch.randn(self.opt.batch_size,
                                self.opt.latent_dim).cuda()

                self.optimizer_G.zero_grad()
                gen_imgs = self.generator(z)
                o_T = self.teacher(gen_imgs)
                pred = o_T.detach().max(1)[1]
                so_T = torch.nn.functional.softmax(o_T, dim=1)
                # log_softmax is exact where softmax underflows to 0, which
                # would otherwise turn 0 * log(0) into NaN.
                log_so_T = torch.nn.functional.log_softmax(o_T, dim=1)
                so_T_mean = so_T.mean(dim=0)

                # IE loss; clamp keeps log() finite if a class mean is 0.
                l_ie = (so_T_mean * torch.log(so_T_mean.clamp(min=1e-12))).sum()
                # one-hot entropy
                l_oh = -(so_T * log_so_T).sum(dim=1).mean()
                # BN loss; stays the int 0 if there are no hooked layers.
                l_bn = sum(mod.G_kd_loss.sum()
                           for mod in self.loss_r_feature_layers)

                l_s = self.opt.alpha * (l_ie + l_oh + l_bn)

                l_s.backward()
                self.optimizer_G.step()

                if i == 0:
                    # float() works for both a tensor and the int-0 fallback.
                    print("[Epoch %d/%d]  [loss_oh: %f] [loss_ie: %f] [loss_BN: %f] "
                          % (epoch, epochs, l_oh.item(), l_ie.item(), float(l_bn)))
            self.scheduler_G.step()
            saved_img_path = self.opt.saved_img_path + 'warm_up/'

            if epoch >= epochs - 3:
                # pred and gen_imgs both come from the last step of this epoch.
                for m in range(gen_imgs.size(0)):
                    save_dir = (saved_img_path + str(epoch) + '/' +
                                str(int(pred[m])) + '/')
                    os.makedirs(save_dir, exist_ok=True)
                    vutils.save_image(gen_imgs[m].detach().clone(),
                                      save_dir + str(m) + '.png',
                                      normalize=True)

            summary.add_image('warmup/generated',
                              pack_images(
                                  denormalize(gen_imgs.detach(),
                                              (0.4914, 0.4822, 0.4465),
                                              (0.2023, 0.1994, 0.2010)).clamp(
                                                  0, 1).cpu().numpy()),
                              global_step=epoch)
            summary.add_scalar('warmup_loss_sum', l_s.item(), epoch)
        os.makedirs(self.opt.saved_model_path, exist_ok=True)
        torch.save(self.generator.state_dict(),
                   self.opt.saved_model_path + 'warm_up_gan.pt')
        print('-' * 30 + 'Warm up end' + '-' * 30)
Пример #3
0
def test(args, student, generator, device, test_loader, epoch=0):
    """Evaluate ``student`` on ``test_loader`` and preview generator samples.

    Accumulates the summed cross-entropy and the top-1 hit count over the
    loader, prints average loss and accuracy, and pushes the first batch's
    real inputs plus one batch of generator samples to the ``vp``
    visualizer.

    Args:
        args: namespace providing ``nz`` (latent channels fed to the
            generator as an ``(N, nz, 1, 1)`` tensor).
        student: classifier under evaluation.
        generator: image generator, used only for the first-batch preview.
        device: device each batch is moved to.
        test_loader: evaluation loader; ``len(test_loader.dataset)`` is the
            denominator for loss and accuracy.
        epoch: accepted for caller compatibility; unused here.

    Returns:
        Top-1 accuracy as a fraction in [0, 1], or 0.0 for an empty dataset.
    """
    student.eval()
    generator.eval()

    # Per-channel normalization stats to undo for display (these match the
    # commonly used CIFAR-10 mean/std).
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)

    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for i, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = student(data)

            if i == 0:
                # Generator output is only visualized, so sample it just for
                # the first batch instead of on every batch.
                z = torch.randn((data.shape[0], args.nz, 1, 1),
                                device=data.device,
                                dtype=data.dtype)
                fake = generator(z)
                vp.add_image(
                    'input',
                    pack_images(
                        denormalize(data, mean, std).clamp(
                            0, 1).detach().cpu().numpy()))
                vp.add_image(
                    'generated',
                    pack_images(
                        denormalize(fake, mean, std).clamp(
                            0, 1).detach().cpu().numpy()))

            # Sum (not mean) so the total can be divided by the dataset size.
            test_loss += F.cross_entropy(output, target,
                                         reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    if num_samples == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        print('\nTest set: empty, nothing to evaluate\n')
        return 0.0

    test_loss /= num_samples
    acc = correct / num_samples

    print(
        '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.format(
            test_loss, correct, num_samples, 100. * acc))
    return acc
Пример #4
0
def test(args, student, teacher, generator, device, test_loader,
         save_dir='results/DFAD-camvid', num_classes=11):
    """Evaluate a segmentation student and optionally dump predictions.

    Runs ``student`` over ``test_loader`` without gradients and accumulates
    its arg-max predictions into ``StreamSegMetrics(num_classes)``.

    When ``args.save_img`` is set, each sample's input image, student
    prediction, ground truth and teacher prediction are written as PNGs
    into ``save_dir``. For the first batch, real and generated inputs and
    the student/teacher predictions on both are pushed to the ``vp``
    visualizer.

    Args:
        args: namespace providing ``save_img`` and ``nz``.
        student, teacher, generator: networks, all put in eval mode.
        device: device each batch is moved to.
        test_loader: loader whose ``dataset`` exposes ``decode_target``.
        save_dir: output directory for PNG dumps (defaults to the previous
            hard-coded path).
        num_classes: number of classes for the metrics (defaults to the
            previous hard-coded 11).

    Returns:
        The dict returned by ``seg_metrics.get_results()``; it contains at
        least ``'Overall Acc'`` and ``'Mean IoU'``.
    """
    student.eval()
    generator.eval()
    teacher.eval()

    seg_metrics = StreamSegMetrics(num_classes)
    if args.save_img:
        os.makedirs(save_dir, exist_ok=True)
    decode = test_loader.dataset.decode_target
    img_idx = 0
    with torch.no_grad():
        for i, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = student(data)

            # Host copies reused by metrics, PNG dumps and visualization.
            pred = output.max(1)[1].detach().cpu().numpy()
            target_np = target.detach().cpu().numpy()

            # Both the dump and the first-batch preview need the teacher's
            # prediction; run the teacher at most once per batch.
            t_pred = None
            if args.save_img or i == 0:
                t_pred = teacher(data).max(1)[1].detach().cpu().numpy()

            if args.save_img:
                # (x + 1) / 2 * 255 maps [-1, 1] inputs to [0, 255]; assumes
                # the loader normalizes to [-1, 1] (TODO confirm). Transpose
                # NCHW -> NHWC for PIL.
                input_imgs = (((data + 1) / 2) * 255).clamp(
                    0, 255).detach().cpu().numpy().transpose(
                        0, 2, 3, 1).astype('uint8')
                colored_preds = decode(pred).astype('uint8')
                colored_teacher_preds = decode(t_pred).astype('uint8')
                colored_targets = decode(target_np).astype('uint8')
                for _pred, _img, _target, _tpred in zip(
                        colored_preds, input_imgs, colored_targets,
                        colored_teacher_preds):
                    prefix = os.path.join(save_dir, '%d_' % img_idx)
                    Image.fromarray(_pred).save(prefix + 'pred.png')
                    Image.fromarray(_img).save(prefix + 'img.png')
                    Image.fromarray(_target).save(prefix + 'target.png')
                    Image.fromarray(_tpred).save(prefix + 'teacher.png')
                    img_idx += 1

            if i == 0:
                # Generator samples are only visualized, so draw them just
                # for the first batch.
                z = torch.randn((data.shape[0], args.nz, 1, 1),
                                device=data.device, dtype=data.dtype)
                fake = generator(z)
                t_out_onfake = teacher(fake)
                s_out_onfake = student(fake)

                def _colorize(labels):
                    # Decode a label map to colors and pack it for display.
                    return pack_images(decode(labels.astype('uint8')),
                                       channel_last=True).astype('uint8')

                vp.add_image('input', pack_images(
                    ((data + 1) / 2).clamp(0, 1).detach().cpu().numpy()))
                vp.add_image('generated', pack_images(
                    ((fake + 1) / 2).clamp(0, 1).detach().cpu().numpy()))
                vp.add_image('target', pack_images(
                    decode(target_np), channel_last=True).astype('uint8'))
                vp.add_image('pred', _colorize(pred))
                vp.add_image('teacher', _colorize(t_pred))
                vp.add_image('teacher-onfake', _colorize(
                    t_out_onfake.max(1)[1].detach().cpu().numpy()))
                vp.add_image('student-onfake', _colorize(
                    s_out_onfake.max(1)[1].detach().cpu().numpy()))

            seg_metrics.update(pred.astype('uint8'), target_np.astype('uint8'))

    results = seg_metrics.get_results()
    print('\nTest set: Acc= %.6f, mIoU: %.6f\n' %
          (results['Overall Acc'], results['Mean IoU']))
    return results
Пример #5
0
    def build(self, summary):
        """Run the main data-free distillation loop.

        Either runs ``self.warm_up`` or loads the warmed-up generator from
        ``<saved_model_path>warm_up_gan.pt``. Then, for each of
        ``self.opt.n_epochs`` epochs and ``self.opt.iter`` iterations:

        * one generator update minimizing
          ``-kd_loss(o_S, o_T) + alpha * (l_ie + l_oh + l_bn)`` (the
          generator seeks samples where student and teacher disagree, while
          keeping the teacher's outputs confident, class-balanced and
          consistent with the BN terms);
        * ten student updates minimizing ``kd_loss(o_S, o_T)`` on fresh
          generator samples.

        After each epoch it logs images and scalars to ``summary``, steps
        both schedulers and calls ``self.test``. During the last three
        epochs it saves the final generated batch into per-class folders.
        Finally it writes ``student.pt`` and ``gan.pt`` and closes
        ``summary``.

        Args:
            summary: logger exposing ``add_image``, ``add_scalar`` and
                ``close``.
        """
        print('-' * 30 + 'Main start' + '-' * 30)

        self.accr_best = 0
        self.accr = 0
        if self.opt.do_warmup == True:
            self.warm_up(summary)
        else:
            # map_location='cpu' lets a GPU-saved checkpoint load on a
            # CPU-only host; load_state_dict copies onto the module's device.
            checkpoint = torch.load(self.opt.saved_model_path +
                                    'warm_up_gan.pt', map_location='cpu')
            self.generator.load_state_dict(checkpoint)
            if torch.cuda.is_available():
                self.generator = self.generator.cuda()

        for epoch in range(self.opt.n_epochs):
            for i in range(self.opt.iter):
                # ---- generator step ----
                self.student.eval()
                self.generator.train()

                z = torch.randn(self.opt.batch_size,
                                self.opt.latent_dim).cuda()
                self.optimizer_G.zero_grad()

                gen_imgs = self.generator(z)
                o_T = self.teacher(gen_imgs)
                o_S = self.student(gen_imgs)
                so_T = torch.nn.functional.softmax(o_T, dim=1)
                # log_softmax is exact where softmax underflows to 0, which
                # would otherwise turn 0 * log(0) into NaN.
                log_so_T = torch.nn.functional.log_softmax(o_T, dim=1)
                so_T_mean = so_T.mean(dim=0)

                # IE loss; clamp keeps log() finite if a class mean is 0.
                l_ie = (so_T_mean * torch.log(so_T_mean.clamp(min=1e-12))).sum()
                # one-hot entropy
                l_oh = -(so_T * log_so_T).sum(dim=1).mean()
                # BN loss; stays the int 0 if there are no hooked layers.
                l_bn = sum(mod.G_kd_loss.sum()
                           for mod in self.loss_r_feature_layers)

                l_s = l_ie + l_oh + l_bn

                l_kd_for_G = kd_loss(o_S, o_T)  # KD loss
                g_loss = -l_kd_for_G + self.opt.alpha * l_s
                g_loss.backward()
                self.optimizer_G.step()

                # ---- student steps ----
                self.student.train()
                self.generator.eval()
                for _ in range(10):
                    self.optimizer_S.zero_grad()

                    z = torch.randn(self.opt.batch_size,
                                    self.opt.latent_dim).cuda()

                    # Only the student is optimized here; skip building
                    # autograd graphs through the generator and teacher.
                    with torch.no_grad():
                        gen_imgs = self.generator(z)
                        o_T = self.teacher(gen_imgs)
                    o_S = self.student(gen_imgs)

                    l_kd_for_S = kd_loss(o_S, o_T)  # KD loss
                    l_kd_for_S.backward()
                    self.optimizer_S.step()

                # Teacher labels for the batch that may be saved below; they
                # must come from the same batch as gen_imgs.
                pred = o_T.max(1)[1]

                if i == 0:
                    if epoch % 10 == 0:
                        # loss_logit reports the generator-side KD loss on
                        # the teacher/student logits.
                        print("[Epoch %d/%d] [loss_logit: %f] [loss_oh: %f] [loss_ie: %f] [loss_BN: %f] [loss_kd: %f]"
                              % (epoch, self.opt.n_epochs, l_kd_for_G.item(),
                                 l_oh.item(), l_ie.item(), float(l_bn),
                                 l_kd_for_S.item()))
                    else:
                        print("[Epoch %d/%d] [loss_kd: %f]" %
                              (epoch, self.opt.n_epochs, l_kd_for_S.item()))

            summary.add_image('main/generated',
                              pack_images(
                                  denormalize(gen_imgs.detach(),
                                              (0.4914, 0.4822, 0.4465),
                                              (0.2023, 0.1994, 0.2010)).clamp(
                                                  0, 1).cpu().numpy()),
                              global_step=epoch)
            summary.add_scalar('main/student_loss', l_kd_for_S.item(), epoch)
            summary.add_scalar('main/generator_loss', g_loss.item(), epoch)
            self.scheduler_S.step()
            self.scheduler_G.step()

            # save generated image per epoch
            self.test(summary, epoch)
            saved_img_path = self.opt.saved_img_path + 'main/'
            if epoch >= self.opt.n_epochs - 3:
                for m in range(gen_imgs.size(0)):
                    save_dir = (saved_img_path + str(epoch) + '/' +
                                str(int(pred[m])) + '/')
                    os.makedirs(save_dir, exist_ok=True)
                    vutils.save_image(gen_imgs[m].detach().clone(),
                                      save_dir + str(m) + '.png',
                                      normalize=True)

        torch.save(self.student.state_dict(),
                   self.opt.saved_model_path + 'student.pt')
        torch.save(self.generator.state_dict(),
                   self.opt.saved_model_path + 'gan.pt')
        summary.close()
        print('-' * 30 + 'Main end' + '-' * 30)