def train(self):
        """Train ``self.model`` on batches drawn from a DALI video pipeline.

        Runs optimisation steps ``self.start_steps`` .. ``self.num_steps - 1``.
        Each batch is a ``b d h w c`` tensor; it is divided by 255
        (presumably uint8 pixels — confirm against the pipeline), normalised
        frame-by-frame with ``self.normalize``, and reshaped so the ``d``
        frames are stacked along the channel axis (``b (d c) h w``) before
        being fed to the model. The optimised loss is
        ``mse(recon, images) + vq_loss``.

        Every 20 steps progress is printed. Every 1000 steps (step 0
        included) a checkpoint is saved, the LR scheduler is stepped once,
        the first sample's frames and their reconstructions are written as
        images, and, if ``self.run_wandb`` is set, the latest losses and LR
        are logged. A final checkpoint tagged ``self.num_steps`` is written
        after the loop.
        """
        self.model.train()

        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')
        meter_loss = AverageMeter('Loss', ':.4e')
        meter_loss_constr = AverageMeter('Constr', ':6.2f')
        meter_loss_perp = AverageMeter('Perplexity', ':6.2f')
        progress = ProgressMeter(
            self.training_loader.epoch_size()['__Video_0'], [
                batch_time, data_time, meter_loss, meter_loss_constr,
                meter_loss_perp
            ],
            prefix="Steps: [{}]".format(self.num_steps))

        data_iter = DALIGenericIterator(self.training_loader, ['data'],
                                        auto_reset=True)
        end = time.time()

        for i in range(self.start_steps, self.num_steps):
            try:
                images = next(data_iter)[0]['data']
            except StopIteration:
                # End of the epoch: rewind the iterator and fetch again.
                data_iter.reset()
                images = next(data_iter)[0]['data']
            # Measure data-loading time *after* the fetch. Measuring before
            # next() only timed the gap between iterations (always ~0).
            data_time.update(time.time() - end)

            images = images.to('cuda')
            b, d, _, _, c = images.size()
            # Normalise each frame independently, then stack the d frames
            # along the channel axis for the model.
            images = rearrange(images, 'b d h w c -> (b d) c h w')
            images = self.normalize(images.float() / 255.)
            images = rearrange(images,
                               '(b d) c h w -> b (d c) h w',
                               b=b,
                               d=d,
                               c=c)
            self.optimizer.zero_grad()

            vq_loss, images_recon, perplexity = self.model(images)
            recon_error = F.mse_loss(images_recon, images)
            loss = recon_error + vq_loss
            loss.backward()

            self.optimizer.step()

            meter_loss_constr.update(recon_error.item(), 1)
            meter_loss_perp.update(perplexity.item(), 1)
            meter_loss.update(loss.item(), 1)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 20 == 0:
                progress.display(i)

            if i % 1000 == 0:
                print('saving ...')
                save_checkpoint(
                    self.folder_name, {
                        'steps': i,
                        'state_dict': self.model.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'scheduler': self.scheduler.state_dict()
                    }, 'checkpoint%s.pth.tar' % i)

                self.scheduler.step()
                # Split the channel axis back into frames for visualisation.
                # The reconstruction is detached so these ops are not
                # recorded by autograd.
                images_vis = rearrange(images, 'b (d c) h w -> b d c h w',
                                       b=b, d=d, c=c)
                recon_vis = rearrange(images_recon.detach(),
                                      'b (d c) h w -> b d c h w',
                                      b=b, d=d, c=c)
                images_orig, images_recs = train_visualize(
                    unnormalize=self.unnormalize,
                    images=images_vis[0, :self.n_images_save],
                    n_images=self.n_images_save,
                    image_recs=recon_vis[0, :self.n_images_save])

                save_images(file_name=os.path.join(self.path_img_orig,
                                                   f'image_{i}.png'),
                            image=images_orig)
                save_images(file_name=os.path.join(self.path_img_recs,
                                                   f'image_{i}.png'),
                            image=images_recs)

                if self.run_wandb:
                    logs = {
                        'iter': i,
                        'loss_recs': meter_loss_constr.val,
                        'loss': meter_loss.val,
                        'lr': self.scheduler.get_last_lr()[0]
                    }
                    self.run_wandb.log(logs)

        print('saving ...')
        save_checkpoint(
            self.folder_name, {
                'steps': self.num_steps,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict(),
            }, 'checkpoint%s.pth.tar' % self.num_steps)
    def train(self):
        """Train ``self.model`` on batches drawn from ``self.training_loader``.

        Runs optimisation steps ``self.start_steps`` .. ``self.num_steps - 1``,
        restarting the loader whenever it is exhausted. The optimised loss is
        ``mse(recon, images) + vq_loss``.

        Every 20 steps progress is printed. Every 1000 steps (step 0
        included) a checkpoint is saved, the LR scheduler is stepped once,
        the first ``self.n_images_save`` inputs and their reconstructions
        are written as images, and, if ``self.run_wandb`` is set, the latest
        losses and LR are logged. A final checkpoint tagged
        ``self.num_steps`` is written after the loop.
        """
        self.model.train()

        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')
        meter_loss = AverageMeter('Loss', ':.4e')
        meter_loss_constr = AverageMeter('Constr', ':6.2f')
        meter_loss_perp = AverageMeter('Perplexity', ':6.2f')
        progress = ProgressMeter(
            len(self.training_loader),
            [batch_time, data_time, meter_loss, meter_loss_constr, meter_loss_perp],
            prefix="Steps: [{}]".format(self.num_steps))

        data_iter = iter(self.training_loader)
        end = time.time()

        for i in range(self.start_steps, self.num_steps):
            try:
                images = next(data_iter)
            except StopIteration:
                # Loader exhausted: start a new pass over the data.
                data_iter = iter(self.training_loader)
                images = next(data_iter)
            # Measure data-loading time *after* the fetch. Measuring before
            # next() only timed the gap between iterations (always ~0).
            data_time.update(time.time() - end)

            images = images.to('cuda')
            self.optimizer.zero_grad()

            vq_loss, images_recon, perplexity = self.model(images)
            recon_error = F.mse_loss(images_recon, images)
            loss = recon_error + vq_loss
            loss.backward()

            self.optimizer.step()

            meter_loss_constr.update(recon_error.item(), 1)
            meter_loss_perp.update(perplexity.item(), 1)
            meter_loss.update(loss.item(), 1)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 20 == 0:
                progress.display(i)

            if i % 1000 == 0:
                print('saving ...')
                save_checkpoint(self.folder_name, {
                    'steps': i,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'scheduler': self.scheduler.state_dict()
                }, 'checkpoint%s.pth.tar' % i)

                self.scheduler.step()
                # Detach the reconstruction so visualisation is not recorded
                # by autograd.
                images_orig, images_recs = train_visualize(
                    unnormalize=self.unnormalize,
                    images=images[:self.n_images_save],
                    n_images=self.n_images_save,
                    image_recs=images_recon[:self.n_images_save].detach())

                save_images(file_name=os.path.join(self.path_img_orig, f'image_{i}.png'), image=images_orig)
                save_images(file_name=os.path.join(self.path_img_recs, f'image_{i}.png'), image=images_recs)

                if self.run_wandb:
                    logs = {
                        'iter': i,
                        'loss_recs': meter_loss_constr.val,
                        'loss': meter_loss.val,
                        'lr': self.scheduler.get_last_lr()[0]
                    }
                    self.run_wandb.log(logs)

        print('saving ...')
        save_checkpoint(self.folder_name, {
            'steps': self.num_steps,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }, 'checkpoint%s.pth.tar' % self.num_steps)
예제 #3
0
def main(args):
    """Train the DALE Visual Attention Network (VAN).

    Pipeline: data loading -> model building -> optimiser setup -> training.

    Overwrites ``args.cuda``, ``args.epochs``, ``args.lr`` and
    ``args.batch_size``; ``args`` is also passed to
    ``dataset_DALE.DALETrain``. The network is trained against the
    ground-truth attention maps with an ``L1_loss`` term plus a 10x-weighted
    ``Perceptual_loss`` term. A checkpoint is saved every ``SAVE_STEP``
    epochs.

    Relies on ``visdom``, ``visdom_loss``, ``visdom_image``, ``L1_loss`` and
    ``Perceptual_loss`` being defined elsewhere in the module.
    """
    # Setting Important Arguments #
    args.cuda = True
    args.epochs = 200
    args.lr = 1e-5
    args.batch_size = 8

    # Setting Important Path #
    # Raw strings: '\d', '\P' and '\D' are invalid escape sequences
    # (SyntaxWarning on Python 3.12+). The resulting values are unchanged.
    train_data_root = r'D:\data\DALE/TRAIN/'
    model_save_root_dir = r'D:\Pytorch_code\DALE/checkpoint/DALE_VAN/'
    val_image_path = '../validation/15.jpg'

    # Setting Important Training Variable #
    VISUALIZATION_STEP = 10
    SAVE_STEP = 1

    print("DALE => Data Loading")

    train_data = dataset_DALE.DALETrain(train_data_root, args)
    loader_train = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True)

    print("DALE => Model Building")
    VisualAttentionNet = VisualAttentionNetwork.VisualAttentionNetwork()

    print("DALE => Set Optimization")
    optG = torch.optim.Adam(list(VisualAttentionNet.parameters()),
                            lr=args.lr,
                            betas=(0.5, 0.999))

    scheduler = lr_scheduler.ExponentialLR(optG, gamma=0.99)

    print("DALE => Setting GPU")
    if args.cuda:
        print("DALE => Use GPU")
        VisualAttentionNet = VisualAttentionNet.cuda()

    # The validation image never changes, so load it once (closing the file
    # handle) instead of re-opening it at every visualisation step.
    with Image.open(val_image_path) as pil_image:
        val_image = transforms.ToTensor()(pil_image).unsqueeze(0)
    if args.cuda:
        val_image = val_image.cuda()

    print("DALE => Training")

    loss_step = 0

    # Epochs are numbered 1..args.epochs to match the "Epoch[x/N]" log line.
    # range(1, args.epochs) silently ran one epoch short.
    for epoch in range(1, args.epochs + 1):

        VisualAttentionNet.train()

        for itr, data in enumerate(loader_train):
            low_light_img, ground_truth_img, gt_Attention_img = data[0], data[1], data[2]
            if args.cuda:
                low_light_img = low_light_img.cuda()
                ground_truth_img = ground_truth_img.cuda()
                gt_Attention_img = gt_Attention_img.cuda()

            optG.zero_grad()

            attention_result = VisualAttentionNet(low_light_img)

            # Logged as "mse_loss", but it is computed with L1_loss.
            mse_loss = L1_loss(attention_result, gt_Attention_img)
            p_loss = Perceptual_loss(attention_result, gt_Attention_img) * 10

            total_loss = p_loss + mse_loss

            total_loss.backward()
            optG.step()

            # After epoch 100, decay the LR once per epoch (on its first batch).
            if epoch > 100 and itr == 0:
                scheduler.step()
                print(scheduler.get_last_lr())

            if itr != 0 and itr % VISUALIZATION_STEP == 0:
                print("Epoch[{}/{}]({}/{}): "
                      "mse_loss : {:.6f}, "
                      "p_loss : {:.6f}"
                      .format(epoch, args.epochs, itr, len(loader_train),
                              mse_loss.item(), p_loss.item()))

                # VISDOM LOSS GRAPH #
                loss_dict = {
                    'mse_loss': mse_loss.item(),
                    'p_loss': p_loss.item(),
                }

                visdom_loss(visdom, loss_step, loss_dict)

                # VISDOM VISUALIZATION # -> list of ('title_name', img_tensor)
                with torch.no_grad():
                    VisualAttentionNet.eval()
                    val_attention = VisualAttentionNet(val_image)
                # Restore training mode. Otherwise eval() would persist for
                # the rest of the epoch (affects BatchNorm/Dropout, if any).
                VisualAttentionNet.train()

                img_list = OrderedDict(
                    [('input', low_light_img),
                     ('attention_output', attention_result),
                     ('gt_Attention_img', gt_Attention_img),
                     ('batch_sum', attention_result + low_light_img),
                     ('ground_truth', ground_truth_img),
                     ('val_attention', val_attention),
                     ('val_sum', val_image + val_attention)
                     ])

                visdom_image(img_dict=img_list, window=10)
                loss_step = loss_step + 1

        print("DALE => Testing")
        if epoch % SAVE_STEP == 0:
            train_utils.save_checkpoint(VisualAttentionNet, epoch, model_save_root_dir)
예제 #4
0
def main(args):
    """Train the DALE Enhancement Network on top of a frozen, pre-trained VAN.

    Overwrites ``args.cuda``, ``args.epochs``, ``args.lr`` and
    ``args.batch_size``; ``args`` is also passed to
    ``dataset_DALE.DALETrain``. The Visual Attention Network weights are
    loaded from ``../checkpoint/VAN.pth`` and are not optimised. Only
    ``EnhanceNet`` is trained, with ``L2_loss`` + 50x ``Perceptual_loss``
    + 20x ``TvLoss`` against the ground-truth images. A checkpoint of
    ``EnhanceNet`` is saved every ``SAVE_STEP`` epochs.

    Relies on ``visdom``, ``visdom_loss``, ``visdom_image``, ``L2_loss``,
    ``Perceptual_loss`` and ``TvLoss`` being defined elsewhere in the module.
    """
    args.cuda = True
    args.epochs = 200
    args.lr = 1e-5
    args.batch_size = 4

    # Setting Important Path #
    # Raw string: '\d' and '\D' are invalid escape sequences (SyntaxWarning
    # on Python 3.12+). The resulting value is unchanged.
    train_data_root = r'D:\data\DALE/TRAIN/'
    model_save_root_dir = '../checkpoint/DALE/'
    model_root = '../checkpoint/'
    val_image_path = '../validation/15.jpg'

    # Setting Important Training Variable #
    VISUALIZATION_STEP = 50
    SAVE_STEP = 1

    print("DALE => Data Loading")

    train_data = dataset_DALE.DALETrain(train_data_root, args)
    loader_train = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True)

    print("DALE => Model Building")
    VisualAttentionNet = VisualAttentionNetwork.VisualAttentionNetwork()

    state_dict = torch.load(model_root + 'VAN.pth')
    VisualAttentionNet.load_state_dict(state_dict)
    # The VAN is pre-trained and not in any optimiser, so keep it in
    # inference mode for the whole run (previously it started in train mode
    # and only switched to eval after the first visualisation).
    VisualAttentionNet.eval()

    EnhanceNet = EnhancementNet.EnhancementNet()

    print("DALE => Set Optimization")
    optG = torch.optim.Adam(list(EnhanceNet.parameters()),
                            lr=args.lr,
                            betas=(0.5, 0.999))

    scheduler = lr_scheduler.ExponentialLR(optG, gamma=0.99)

    # Number of trainable parameters in the enhancement network.
    params1 = sum(p.numel() for p in EnhanceNet.parameters() if p.requires_grad)

    print("Parameters | ", params1)

    print("DALE => Setting GPU")
    if args.cuda:
        print("DALE => Use GPU")
        VisualAttentionNet = VisualAttentionNet.cuda()
        EnhanceNet = EnhanceNet.cuda()

    # The validation image never changes, so load it once (closing the file
    # handle) instead of re-opening it at every visualisation step.
    with Image.open(val_image_path) as pil_image:
        val_image = transforms.ToTensor()(pil_image).unsqueeze(0)
    if args.cuda:
        val_image = val_image.cuda()

    print("DALE => Training")

    loss_step = 0

    # Epochs are numbered 1..args.epochs to match the "Epoch[x/N]" log line.
    # range(1, args.epochs) silently ran one epoch short.
    for epoch in range(1, args.epochs + 1):

        EnhanceNet.train()

        for itr, data in enumerate(loader_train):
            low_light_img, ground_truth_img, gt_Attention_img = data[0], data[1], data[2]
            if args.cuda:
                low_light_img = low_light_img.cuda()
                ground_truth_img = ground_truth_img.cuda()
                gt_Attention_img = gt_Attention_img.cuda()

            optG.zero_grad()

            # The frozen VAN only supplies an input to EnhanceNet, so no
            # autograd graph is needed for it.
            with torch.no_grad():
                attention_result = VisualAttentionNet(low_light_img)
            enhance_result = EnhanceNet(low_light_img, attention_result)

            mse_loss = L2_loss(enhance_result, ground_truth_img)
            p_loss = Perceptual_loss(enhance_result, ground_truth_img) * 50
            tv_loss = TvLoss(enhance_result) * 20

            total_loss = p_loss + mse_loss + tv_loss

            total_loss.backward()
            optG.step()

            # After epoch 100, decay the LR once per epoch (on its first batch).
            if epoch > 100 and itr == 0:
                scheduler.step()
                print(scheduler.get_last_lr())

            if itr != 0 and itr % VISUALIZATION_STEP == 0:
                print("Epoch[{}/{}]({}/{}): "
                      "mse_loss : {:.6f}, "
                      "tv_loss : {:.6f}, "
                      "p_loss : {:.6f}"
                      .format(epoch, args.epochs, itr, len(loader_train),
                              mse_loss.item(), tv_loss.item(), p_loss.item()))

                # VISDOM LOSS GRAPH #
                loss_dict = {
                    'mse_loss': mse_loss.item(),
                    'tv_loss': tv_loss.item(),
                    'p_loss': p_loss.item(),
                }

                visdom_loss(visdom, loss_step, loss_dict)

                # VISDOM VISUALIZATION # -> list of ('title_name', img_tensor)
                with torch.no_grad():
                    EnhanceNet.eval()
                    val_attention = VisualAttentionNet(val_image)
                    val_result = EnhanceNet(val_image, val_attention)
                # Restore training mode. Otherwise eval() would persist for
                # the rest of the epoch (affects BatchNorm/Dropout, if any).
                EnhanceNet.train()

                img_list = OrderedDict([('input', low_light_img),
                                        ('output', enhance_result),
                                        ('attention_output', attention_result),
                                        ('gt_Attention_img', gt_Attention_img),
                                        ('batch_sum',
                                         attention_result + low_light_img),
                                        ('ground_truth', ground_truth_img),
                                        ('val_result', val_result)])

                visdom_image(img_dict=img_list, window=10)

                loss_step = loss_step + 1

        print("DALE => Testing")
        if epoch % SAVE_STEP == 0:
            train_utils.save_checkpoint(EnhanceNet, epoch, model_save_root_dir)
예제 #5
0
def main(args):
    """Fine-tune the DALE enhancement generator adversarially (WGAN-style).

    Pipeline: data loading -> model building -> optimiser setup -> training.

    Overwrites ``args.cuda``, ``args.epochs``, ``args.lr`` and
    ``args.batch_size``; ``args`` is also passed to
    ``dataset_DALE.DALETrainGlobal``. A pre-trained attention network
    (``AttentionNet2``) and generator are loaded from
    ``../checkpoint/DALEGAN/``. The attention network is frozen. Each batch
    performs one critic (``EnhanceNetD``) update with weight clipping. Every
    ``N_CRITIC`` batches it also performs one generator (``EnhanceNetG``)
    update using ``L2_loss`` + 10x ``Perceptual_loss`` + 5x ``TvLoss`` plus
    a 0.5x adversarial term. Both networks are checkpointed every
    ``SAVE_STEP`` epochs.

    Relies on ``visdom``, ``visdom_loss``, ``visdom_image``, ``L2_loss``,
    ``Perceptual_loss`` and ``TvLoss`` being defined elsewhere in the module.
    """
    # Setting Important Arguments #
    args.cuda = True
    args.epochs = 200
    args.lr = 1e-5
    args.batch_size = 5
    # Setting Important Path #
    # Raw strings: '\d', '\P' and '\D' are invalid escape sequences
    # (SyntaxWarning on Python 3.12+). The resulting values are unchanged.
    train_data_root = r'D:\data\DALE/'
    model_save_root_dir = r'D:\Pytorch_code\DALE/checkpoint/'
    model_root = '../checkpoint/DALEGAN/'
    val_image_path = '../validation/15.jpg'

    # Setting Important Training Variable #
    # VISUALIZATION_STEP must stay a multiple of N_CRITIC: the logged
    # generator losses are only computed on generator-update iterations.
    VISUALIZATION_STEP = 50
    SAVE_STEP = 1
    N_CRITIC = 5  # critic updates per generator update

    print("DALE => Data Loading")

    train_data = dataset_DALE.DALETrainGlobal(train_data_root, args)
    loader_train = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True)

    print("DALE => Model Building")
    VAN = VisualAttentionNetwork.AttentionNet2()
    state_dict1 = torch.load(model_root + 'visual_attention_network_model.pth')
    VAN.load_state_dict(state_dict1)
    # The attention network is pre-trained and in no optimiser: keep it in
    # inference mode for the whole run.
    VAN.eval()

    EnhanceNetG = EnhancementNet.EnhancementNet()
    EnhanceNetD = EnhancementNet.Discriminator()

    state_dict2 = torch.load(model_root + 'enhance_GAN.pth')
    EnhanceNetG.load_state_dict(state_dict2)

    params1 = sum(p.numel() for p in EnhanceNetG.parameters() if p.requires_grad)

    # This count is for the generator (it was previously mislabelled).
    print("Parameters | Generator ", params1)

    params = sum(p.numel() for p in EnhanceNetD.parameters() if p.requires_grad)

    print("Parameters | Discriminator ", params)

    print("DALE => Set Optimization")
    optG = torch.optim.Adam(list(EnhanceNetG.parameters()),
                            lr=args.lr,
                            betas=(0.5, 0.999))
    optD = torch.optim.Adam(list(EnhanceNetD.parameters()),
                            lr=args.lr,
                            betas=(0.5, 0.999),
                            weight_decay=0)

    print("DALE => Setting GPU")
    if args.cuda:
        print("DALE => Use GPU")
        VAN = VAN.cuda()
        EnhanceNetG = EnhanceNetG.cuda()
        EnhanceNetD = EnhanceNetD.cuda()

    # The validation image never changes, so load it once (closing the file
    # handle) instead of re-opening it at every visualisation step.
    with Image.open(val_image_path) as pil_image:
        val_image = transforms.ToTensor()(pil_image).unsqueeze(0)
    if args.cuda:
        val_image = val_image.cuda()

    print("DALE => Training")

    loss_step = 0

    for epoch in range(args.epochs):

        EnhanceNetG.train()
        EnhanceNetD.train()
        for itr, data in enumerate(loader_train):
            low_light_img, ground_truth_img, gt_Attention_img = data[0], data[1], data[2]
            if args.cuda:
                low_light_img = low_light_img.cuda()
                ground_truth_img = ground_truth_img.cuda()
                gt_Attention_img = gt_Attention_img.cuda()

            # ---- critic (discriminator) step ----
            optD.zero_grad()

            with torch.no_grad():
                # Frozen VAN: no gradients should flow into its parameters.
                attention_result = VAN(low_light_img)
                # Fake samples for the critic; no generator graph needed.
                enhance_result = EnhanceNetG(low_light_img, attention_result)

            loss_D = -torch.mean(EnhanceNetD(ground_truth_img)) \
                     + torch.mean(EnhanceNetD(enhance_result))

            loss_D.backward()
            optD.step()

            # WGAN-style weight clipping of the critic.
            for p in EnhanceNetD.parameters():
                p.data.clamp_(-0.01, 0.01)

            # ---- generator step, once every N_CRITIC critic steps ----
            if itr % N_CRITIC == 0:

                optG.zero_grad()
                enhance_result = EnhanceNetG(low_light_img, attention_result)
                # The adversarial term scores the fake with the critic (D).
                # Previously the generator's output was fed back into the
                # generator itself.
                loss_G = -torch.mean(EnhanceNetD(enhance_result)) * 0.5

                e_loss = L2_loss(enhance_result, ground_truth_img)
                p_loss = Perceptual_loss(enhance_result, ground_truth_img) * 10
                tv_loss = TvLoss(enhance_result) * 5

                total_loss = p_loss + e_loss + tv_loss + loss_G

                total_loss.backward()
                optG.step()

            if itr != 0 and itr % VISUALIZATION_STEP == 0:
                print("Epoch[{}/{}]({}/{}): "
                      "e_loss : {:.6f}, "
                      "tv_loss : {:.6f}, "
                      "p_loss : {:.6f}"
                      .format(epoch, args.epochs, itr, len(loader_train),
                              e_loss.item(), tv_loss.item(), p_loss.item()))

                # VISDOM LOSS GRAPH #
                loss_dict = {
                    'e_loss': e_loss.item(),
                    'tv_loss': tv_loss.item(),
                    'p_loss': p_loss.item(),
                    'g_loss': loss_G.item(),
                    'd_loss': loss_D.item()
                }

                visdom_loss(visdom, loss_step, loss_dict)

                # VISDOM VISUALIZATION # -> list of ('title_name', img_tensor)
                with torch.no_grad():
                    EnhanceNetG.eval()
                    val_attention = VAN(val_image)
                    val_result = EnhanceNetG(val_image, val_attention)
                # Restore training mode. Otherwise eval() would persist for
                # the rest of the epoch (affects BatchNorm/Dropout, if any).
                EnhanceNetG.train()

                img_list = OrderedDict([('input', low_light_img),
                                        ('output', enhance_result),
                                        ('attention_output', attention_result),
                                        ('gt_Attention_img', gt_Attention_img),
                                        ('ground_truth', ground_truth_img),
                                        ('val_result', val_result),
                                        ('val_sum', val_attention + val_image)
                                        ])

                visdom_image(img_dict=img_list, window=10)

                loss_step = loss_step + 1

        print("DALE => Testing")

        if epoch % SAVE_STEP == 0:
            train_utils.save_checkpoint(EnhanceNetG, epoch,
                                        model_save_root_dir + 'DALEGAN/')
            train_utils.save_checkpoint(
                EnhanceNetD, epoch,
                model_save_root_dir + 'DALE_Discriminator/')