Example #1
0
def test(args, loader_test, model_AttentionNet, epoch, root_dir) :
    """Run the attention model over the test loader and save each output image.

    Results are written to ``root_dir`` with the source file stem plus an
    epoch/iteration suffix.
    """
    model_AttentionNet.eval()
    for itr, batch in enumerate(loader_test):
        image, names = batch[0], batch[1]
        if args.cuda:
            image = image.cuda()

        with torch.no_grad():
            prediction = model_AttentionNet(image)
            prediction_img = train_utils.tensor2im(prediction)
            stem = names[0].split('.')[0]
            suffix = '_epoch_{}_itr_{}.png'.format(epoch, itr)
            save_path = root_dir + stem + suffix
            train_utils.save_images(prediction_img, save_path)
Example #2
0
def test(loader_test, VAN, EN, root_dir):
    """Enhance every test image with the attention/enhancement model pair.

    The visual attention network (VAN) produces an attention map that the
    enhancement network (EN) consumes together with the input; results are
    saved under ``root_dir`` with an ``enhance`` prefix.
    """
    VAN.eval()
    EN.eval()

    for batch in loader_test:
        image, names = batch[0], batch[1]
        image = image.cuda()

        with torch.no_grad():
            attention_map = VAN(image)
            enhanced = EN(image, attention_map)
            enhanced_img = train_utils.tensor2im(enhanced)
            stem = names[0].split('.')[0]
            save_path = root_dir + 'enhance' + stem + '.png'
            train_utils.save_images(enhanced_img, save_path)
Example #3
0
def test(loader_test, visualAttentionNet, root_dir):
    """Save the predicted visual attention map, and its sum with the input,
    for every image in the test loader.

    Two files are written per image under ``root_dir``: the raw attention map
    (``visual_attention_map_`` prefix) and input + attention (``sum_`` prefix).
    """
    visualAttentionNet.eval()
    for batch in loader_test:
        image, names = batch[0], batch[1]
        image = image.cuda()

        with torch.no_grad():
            attention = visualAttentionNet(image)

            attention_img = train_utils.tensor2im(attention)
            summed_img = train_utils.tensor2im(image + attention)

            stem = names[0].split('.')[0]
            attention_path = root_dir + 'visual_attention_map_' + stem + '.png'
            summed_path = root_dir + 'sum_' + stem + '.png'

            train_utils.save_images(attention_img, attention_path)
            train_utils.save_images(summed_img, summed_path)
Example #4
0
def reconstruct(checkpoint_path, data_args, model_args):
    """Reconstruct one batch of training video frames with a trained VqVae.

    Builds the DALI video pipeline described by ``data_args``, restores the
    model weights from ``checkpoint_path``, runs a single batch through the
    model, and writes the original and reconstructed frames to
    ``images_orig.png`` / ``images_recon.png``.

    Args:
        checkpoint_path: path of the checkpoint holding ``state_dict``.
        data_args: dict with ``batch_size``, ``num_threads``, ``device_id``,
            ``training_data_files`` and ``seed`` for the video pipeline.
        model_args: keyword arguments forwarded to the ``VqVae`` constructor.
    """
    training_loader = video_pipe(batch_size=data_args['batch_size'],
                                 num_threads=data_args['num_threads'],
                                 device_id=data_args['device_id'],
                                 filenames=data_args['training_data_files'],
                                 seed=data_args['seed'])
    training_loader.build()
    training_loader = DALIGenericIterator(training_loader, ['data'])
    checkpoint = load_checkpoint(checkpoint_path, device_id=0)

    model = VqVae(**model_args).to('cuda')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # ImageNet statistics; NormalizeInverse undoes the normalization for
    # visualization.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    unnormalize = NormalizeInverse(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])

    # NOTE(review): assumes the pipeline yields (batch, depth, height, width,
    # channels) uint8 frames -- confirm against video_pipe's output layout.
    images = next(training_loader)[0]['data']
    b, d, _, _, c = images.size()
    images = rearrange(images, 'b d h w c -> (b d) c h w')
    images = normalize(images.float() / 255.)
    # Fold the frame (depth) axis into channels: the model consumes
    # (b, d*c, h, w).
    images = rearrange(images, '(b d) c h w -> b (d c) h w', b=b, d=d, c=c)

    # Inference only: disable autograd so no graph/activations are kept.
    with torch.no_grad():
        vq_loss, images_recon, _ = model(images)
    print('reconstruct error: %6.2f' % vq_loss)
    # Unfold channels back into per-frame images for visualization.
    images, images_recon = map(
        lambda t: rearrange(t, 'b (d c) h w -> (b d) c h w', b=b, d=d, c=c),
        [images, images_recon])
    images_orig, images_recs = train_visualize(unnormalize=unnormalize,
                                               images=images,
                                               n_images=b * d,
                                               image_recs=images_recon)

    save_images(file_name='images_orig.png', image=images_orig)
    save_images(file_name='images_recon.png', image=images_recs)
    def train(self):
        """Run the optimization loop from ``self.start_steps`` to ``self.num_steps``.

        Each step pulls one video batch from a DALI pipeline, folds the frame
        (depth) axis into the channel axis, and optimizes the model on
        reconstruction (MSE) loss plus codebook (VQ) loss.  Progress is printed
        every 20 steps; every 1000 steps a checkpoint is saved, the LR scheduler
        is stepped, and sample reconstructions are written to disk (and
        optionally logged to wandb).  A final checkpoint is saved after the loop.
        """
        self.model.train()

        # Running meters for timing and loss reporting.
        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')
        meter_loss = AverageMeter('Loss', ':.4e')
        meter_loss_constr = AverageMeter('Constr', ':6.2f')
        meter_loss_perp = AverageMeter('Perplexity', ':6.2f')
        progress = ProgressMeter(
            self.training_loader.epoch_size()['__Video_0'], [
                batch_time, data_time, meter_loss, meter_loss_constr,
                meter_loss_perp
            ],
            prefix="Steps: [{}]".format(self.num_steps))

        data_iter = DALIGenericIterator(self.training_loader, ['data'],
                                        auto_reset=True)
        end = time.time()

        for i in range(self.start_steps, self.num_steps):
            # measure output loading time
            data_time.update(time.time() - end)

            # auto_reset=True should restart the iterator on epoch end, but
            # guard against StopIteration and reset manually just in case.
            try:
                images = next(data_iter)[0]['data']
            except StopIteration:
                data_iter.reset()
                images = next(data_iter)[0]['data']

            images = images.to('cuda')
            # NOTE(review): assumes DALI yields (batch, depth, height, width,
            # channels) uint8 frames -- confirm against the pipeline definition.
            b, d, _, _, c = images.size()
            images = rearrange(images, 'b d h w c -> (b d) c h w')
            images = self.normalize(images.float() / 255.)
            # Fold the frame (depth) axis into channels: the model consumes
            # (b, d*c, h, w).
            images = rearrange(images,
                               '(b d) c h w -> b (d c) h w',
                               b=b,
                               d=d,
                               c=c)
            self.optimizer.zero_grad()

            vq_loss, images_recon, perplexity = self.model(images)
            recon_error = F.mse_loss(images_recon, images)
            loss = recon_error + vq_loss
            loss.backward()

            self.optimizer.step()

            meter_loss_constr.update(recon_error.item(), 1)
            meter_loss_perp.update(perplexity.item(), 1)
            meter_loss.update(loss.item(), 1)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 20 == 0:
                progress.display(i)

            if i % 1000 == 0:
                print('saving ...')
                save_checkpoint(
                    self.folder_name, {
                        'steps': i,
                        'state_dict': self.model.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'scheduler': self.scheduler.state_dict()
                    }, 'checkpoint%s.pth.tar' % i)

                # NOTE(review): the LR scheduler is stepped only once per 1000
                # iterations, coupled to checkpointing -- confirm intentional.
                self.scheduler.step()
                # Unfold channels back into per-frame images for visualization.
                images, images_recon = map(
                    lambda t: rearrange(
                        t, 'b (d c) h w -> b d c h w', b=b, d=d, c=c),
                    [images, images_recon])
                images_orig, images_recs = train_visualize(
                    unnormalize=self.unnormalize,
                    images=images[0, :self.n_images_save],
                    n_images=self.n_images_save,
                    image_recs=images_recon[0, :self.n_images_save])

                save_images(file_name=os.path.join(self.path_img_orig,
                                                   f'image_{i}.png'),
                            image=images_orig)
                save_images(file_name=os.path.join(self.path_img_recs,
                                                   f'image_{i}.png'),
                            image=images_recs)

                if self.run_wandb:
                    logs = {
                        'iter': i,
                        'loss_recs': meter_loss_constr.val,
                        'loss': meter_loss.val,
                        'lr': self.scheduler.get_last_lr()[0]
                    }
                    self.run_wandb.log(logs)

        # Final checkpoint after the loop completes.
        print('saving ...')
        save_checkpoint(
            self.folder_name, {
                'steps': self.num_steps,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict(),
            }, 'checkpoint%s.pth.tar' % self.num_steps)
    def train(self):
        """Run the optimization loop from ``self.start_steps`` to ``self.num_steps``.

        Variant of the training loop driven by a plain PyTorch DataLoader
        (``self.training_loader``) whose batches are already model-ready: each
        step optimizes the model on reconstruction (MSE) loss plus codebook
        (VQ) loss.  Progress is printed every 20 steps; every 1000 steps a
        checkpoint is saved, the LR scheduler is stepped, and sample
        reconstructions are written to disk (and optionally logged to wandb).
        A final checkpoint is saved after the loop.
        """
        self.model.train()

        # Running meters for timing and loss reporting.
        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')
        meter_loss = AverageMeter('Loss', ':.4e')
        meter_loss_constr = AverageMeter('Constr', ':6.2f')
        meter_loss_perp = AverageMeter('Perplexity', ':6.2f')
        progress = ProgressMeter(
            len(self.training_loader),
            [batch_time, data_time, meter_loss, meter_loss_constr, meter_loss_perp],
            prefix="Steps: [{}]".format(self.num_steps))

        data_iter = iter(self.training_loader)
        end = time.time()

        for i in range(self.start_steps, self.num_steps):
            # measure output loading time
            data_time.update(time.time() - end)

            # The step count can exceed one epoch: re-create the iterator when
            # the loader is exhausted.
            try:
                images = next(data_iter)
            except StopIteration:
                data_iter = iter(self.training_loader)
                images = next(data_iter)

            images = images.to('cuda')
            self.optimizer.zero_grad()

            vq_loss, images_recon, perplexity = self.model(images)
            recon_error = F.mse_loss(images_recon, images)
            loss = recon_error + vq_loss
            loss.backward()

            self.optimizer.step()

            meter_loss_constr.update(recon_error.item(), 1)
            meter_loss_perp.update(perplexity.item(), 1)
            meter_loss.update(loss.item(), 1)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 20 == 0:
                progress.display(i)

            if i % 1000 == 0:
                print('saving ...')
                save_checkpoint(self.folder_name, {
                    'steps': i,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'scheduler': self.scheduler.state_dict()
                }, 'checkpoint%s.pth.tar' % i)

                # NOTE(review): the LR scheduler is stepped only once per 1000
                # iterations, coupled to checkpointing -- confirm intentional.
                self.scheduler.step()
                images_orig, images_recs = train_visualize(
                    unnormalize=self.unnormalize, images=images[:self.n_images_save], n_images=self.n_images_save,
                    image_recs=images_recon[:self.n_images_save])

                save_images(file_name=os.path.join(self.path_img_orig, f'image_{i}.png'), image=images_orig)
                save_images(file_name=os.path.join(self.path_img_recs, f'image_{i}.png'), image=images_recs)

                if self.run_wandb:
                    logs = {
                        'iter': i,
                        'loss_recs': meter_loss_constr.val,
                        'loss': meter_loss.val,
                        'lr': self.scheduler.get_last_lr()[0]
                    }
                    self.run_wandb.log(logs)

        # Final checkpoint after the loop completes.
        print('saving ...')
        save_checkpoint(self.folder_name, {
            'steps': self.num_steps,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }, 'checkpoint%s.pth.tar' % self.num_steps)