def train(self):
    self.model.train()

    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    meter_loss = AverageMeter('Loss', ':.4e')
    meter_loss_constr = AverageMeter('Constr', ':6.2f')
    meter_loss_perp = AverageMeter('Perplexity', ':6.2f')
    progress = ProgressMeter(
        self.training_loader.epoch_size()['__Video_0'],
        [batch_time, data_time, meter_loss, meter_loss_constr, meter_loss_perp],
        prefix="Steps: [{}]".format(self.num_steps))

    data_iter = DALIGenericIterator(self.training_loader, ['data'], auto_reset=True)
    end = time.time()
    for i in range(self.start_steps, self.num_steps):
        # measure data loading time
        data_time.update(time.time() - end)

        try:
            images = next(data_iter)[0]['data']
        except StopIteration:
            data_iter.reset()
            images = next(data_iter)[0]['data']
        images = images.to('cuda')

        # normalize the frames, then fold the depth dimension into the channels
        b, d, _, _, c = images.size()
        images = rearrange(images, 'b d h w c -> (b d) c h w')
        images = self.normalize(images.float() / 255.)
        images = rearrange(images, '(b d) c h w -> b (d c) h w', b=b, d=d, c=c)

        self.optimizer.zero_grad()
        vq_loss, images_recon, perplexity = self.model(images)
        recon_error = F.mse_loss(images_recon, images)
        loss = recon_error + vq_loss
        loss.backward()
        self.optimizer.step()

        meter_loss_constr.update(recon_error.item(), 1)
        meter_loss_perp.update(perplexity.item(), 1)
        meter_loss.update(loss.item(), 1)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 20 == 0:
            progress.display(i)

        if i % 1000 == 0:
            print('saving ...')
            save_checkpoint(
                self.folder_name, {
                    'steps': i,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'scheduler': self.scheduler.state_dict()
                }, 'checkpoint%s.pth.tar' % i)

            self.scheduler.step()

            # unfold depth again for visualization
            images, images_recon = map(
                lambda t: rearrange(t, 'b (d c) h w -> b d c h w', b=b, d=d, c=c),
                [images, images_recon])
            images_orig, images_recs = train_visualize(
                unnormalize=self.unnormalize,
                images=images[0, :self.n_images_save],
                n_images=self.n_images_save,
                image_recs=images_recon[0, :self.n_images_save])

            save_images(file_name=os.path.join(self.path_img_orig, f'image_{i}.png'),
                        image=images_orig)
            save_images(file_name=os.path.join(self.path_img_recs, f'image_{i}.png'),
                        image=images_recs)

        if self.run_wandb:
            logs = {
                'iter': i,
                'loss_recs': meter_loss_constr.val,
                'loss': meter_loss.val,
                'lr': self.scheduler.get_last_lr()[0]
            }
            self.run_wandb.log(logs)

    print('saving ...')
    save_checkpoint(
        self.folder_name, {
            'steps': self.num_steps,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }, 'checkpoint%s.pth.tar' % self.num_steps)
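# The trainer above assumes self.training_loader is a built DALI pipeline whose
# video reader is named '__Video_0' and whose single output is mapped to 'data'.
# The pipeline itself is not shown; a minimal sketch of what it might look like,
# assuming the DALI >= 1.0 functional API (video_root and sequence_length are
# illustrative names, not taken from the source):
from nvidia.dali import pipeline_def, fn

@pipeline_def
def video_pipeline(video_root, sequence_length):
    # yields uint8 frame sequences shaped (batch, depth, height, width, channels),
    # matching the 'b d h w c' layout unpacked in train()
    return fn.readers.video(device='gpu', file_root=video_root,
                            sequence_length=sequence_length,
                            random_shuffle=True, name='__Video_0')

# usage sketch: pipe = video_pipeline(video_root='...', sequence_length=8,
#                                     batch_size=2, num_threads=2, device_id=0)
#               pipe.build()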
def train(self):
    self.model.train()

    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    meter_loss = AverageMeter('Loss', ':.4e')
    meter_loss_constr = AverageMeter('Constr', ':6.2f')
    meter_loss_perp = AverageMeter('Perplexity', ':6.2f')
    progress = ProgressMeter(
        len(self.training_loader),
        [batch_time, data_time, meter_loss, meter_loss_constr, meter_loss_perp],
        prefix="Steps: [{}]".format(self.num_steps))

    data_iter = iter(self.training_loader)
    end = time.time()
    for i in range(self.start_steps, self.num_steps):
        # measure data loading time
        data_time.update(time.time() - end)

        try:
            images = next(data_iter)
        except StopIteration:
            # restart the iterator once the loader is exhausted
            data_iter = iter(self.training_loader)
            images = next(data_iter)
        images = images.to('cuda')

        self.optimizer.zero_grad()
        vq_loss, images_recon, perplexity = self.model(images)
        recon_error = F.mse_loss(images_recon, images)
        loss = recon_error + vq_loss
        loss.backward()
        self.optimizer.step()

        meter_loss_constr.update(recon_error.item(), 1)
        meter_loss_perp.update(perplexity.item(), 1)
        meter_loss.update(loss.item(), 1)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 20 == 0:
            progress.display(i)

        if i % 1000 == 0:
            print('saving ...')
            save_checkpoint(
                self.folder_name, {
                    'steps': i,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'scheduler': self.scheduler.state_dict()
                }, 'checkpoint%s.pth.tar' % i)

            self.scheduler.step()

            images_orig, images_recs = train_visualize(
                unnormalize=self.unnormalize,
                images=images[:self.n_images_save],
                n_images=self.n_images_save,
                image_recs=images_recon[:self.n_images_save])

            save_images(file_name=os.path.join(self.path_img_orig, f'image_{i}.png'),
                        image=images_orig)
            save_images(file_name=os.path.join(self.path_img_recs, f'image_{i}.png'),
                        image=images_recs)

        if self.run_wandb:
            logs = {
                'iter': i,
                'loss_recs': meter_loss_constr.val,
                'loss': meter_loss.val,
                'lr': self.scheduler.get_last_lr()[0]
            }
            self.run_wandb.log(logs)

    print('saving ...')
    save_checkpoint(
        self.folder_name, {
            'steps': self.num_steps,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }, 'checkpoint%s.pth.tar' % self.num_steps)
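# Both train() methods above rely on AverageMeter / ProgressMeter helpers whose
# definitions are not shown. A minimal sketch, modeled on the metering utilities
# from the official PyTorch ImageNet example (assumed here, since this repo's
# versions may differ):
class AverageMeter:
    """Tracks the latest value and the running average of a metric."""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter:
    """Prints one line summarizing all meters at a given step."""
    def __init__(self, num_batches, meters, prefix=""):
        self.num_batches = num_batches
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + '[{}/{}]'.format(batch, self.num_batches)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))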
def main(args):
    """Main Function: Data Loading -> Model Building -> Set Optimization -> Training"""
    # Setting Important Arguments #
    args.cuda = True
    args.epochs = 200
    args.lr = 1e-5
    args.batch_size = 8

    # Setting Important Path #
    # (forward slashes avoid invalid '\d' escape sequences on Windows)
    train_data_root = 'D:/data/DALE/TRAIN/'
    model_save_root_dir = 'D:/Pytorch_code/DALE/checkpoint/DALE_VAN/'
    model_root = '../checkpoint/DALE/'

    # Setting Important Training Variables #
    VISUALIZATION_STEP = 10
    SAVE_STEP = 1

    print("DALE => Data Loading")
    train_data = dataset_DALE.DALETrain(train_data_root, args)
    loader_train = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)

    print("DALE => Model Building")
    VisualAttentionNet = VisualAttentionNetwork.VisualAttentionNetwork()

    print("DALE => Set Optimization")
    optG = torch.optim.Adam(VisualAttentionNet.parameters(),
                            lr=args.lr, betas=(0.5, 0.999))
    scheduler = lr_scheduler.ExponentialLR(optG, gamma=0.99)

    print("DALE => Setting GPU")
    if args.cuda:
        print("DALE => Use GPU")
        VisualAttentionNet = VisualAttentionNet.cuda()

    print("DALE => Training")
    loss_step = 0
    for epoch in range(1, args.epochs):
        VisualAttentionNet.train()
        for itr, data in enumerate(loader_train):
            low_light_img, ground_truth_img, gt_Attention_img, file_name = \
                data[0], data[1], data[2], data[3]

            if args.cuda:
                low_light_img = low_light_img.cuda()
                ground_truth_img = ground_truth_img.cuda()
                gt_Attention_img = gt_Attention_img.cuda()

            optG.zero_grad()

            attention_result = VisualAttentionNet(low_light_img)

            # note: despite the name, this is an L1 reconstruction loss
            mse_loss = L1_loss(attention_result, gt_Attention_img)
            p_loss = Perceptual_loss(attention_result, gt_Attention_img) * 10
            total_loss = p_loss + mse_loss

            total_loss.backward()
            optG.step()

            # decay the learning rate once per epoch after epoch 100
            if epoch > 100 and itr == 0:
                scheduler.step()
                print(scheduler.get_last_lr())

            if itr != 0 and itr % VISUALIZATION_STEP == 0:
                print("Epoch[{}/{}]({}/{}): "
                      "mse_loss : {:.6f}, "
                      "p_loss : {:.6f}"
                      .format(epoch, args.epochs, itr, len(loader_train),
                              mse_loss, p_loss))

                # VISDOM LOSS GRAPH #
                loss_dict = {
                    'mse_loss': mse_loss.item(),
                    'p_loss': p_loss.item(),
                }
                visdom_loss(visdom, loss_step, loss_dict)

                # VISDOM VISUALIZATION # -> tensor to numpy => list ('title_name', img_tensor)
                with torch.no_grad():
                    val_image = Image.open('../validation/15.jpg')
                    transform = transforms.Compose([transforms.ToTensor()])
                    val_image = transform(val_image).unsqueeze(0)
                    val_image = val_image.cuda()
                    val_attention = VisualAttentionNet.eval()(val_image)

                    img_list = OrderedDict([
                        ('input', low_light_img),
                        ('attention_output', attention_result),
                        ('gt_Attention_img', gt_Attention_img),
                        ('batch_sum', attention_result + low_light_img),
                        ('ground_truth', ground_truth_img),
                        ('val_attention', val_attention),
                        ('val_sum', val_image + val_attention),
                    ])
                    visdom_image(img_dict=img_list, window=10)
                VisualAttentionNet.train()  # restore training mode after validation

            loss_step = loss_step + 1

        print("DALE => Testing")
        if epoch % SAVE_STEP == 0:
            train_utils.save_checkpoint(VisualAttentionNet, epoch, model_save_root_dir)
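# visdom_loss() above is a project helper that is not shown in this section.
# A minimal sketch of how such a helper typically appends scalar losses to
# visdom line plots; the one-window-per-loss layout is an assumption, not the
# repo's actual implementation:
import numpy as np

def visdom_loss_sketch(vis, step, loss_dict):
    """Append each named loss value to its own visdom line plot."""
    for name, value in loss_dict.items():
        vis.line(X=np.array([step]), Y=np.array([value]),
                 win=name, update='append', opts={'title': name})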
def main(args):
    # Setting Important Arguments #
    args.cuda = True
    args.epochs = 200
    args.lr = 1e-5
    args.batch_size = 4

    # Setting Important Path #
    # (forward slashes avoid invalid '\d' escape sequences on Windows)
    train_data_root = 'D:/data/DALE/TRAIN/'
    model_save_root_dir = '../checkpoint/DALE/'
    model_root = '../checkpoint/'

    # Setting Important Training Variables #
    VISUALIZATION_STEP = 50
    SAVE_STEP = 1

    print("DALE => Data Loading")
    train_data = dataset_DALE.DALETrain(train_data_root, args)
    loader_train = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)

    print("DALE => Model Building")
    VisualAttentionNet = VisualAttentionNetwork.VisualAttentionNetwork()
    state_dict = torch.load(model_root + 'VAN.pth')
    VisualAttentionNet.load_state_dict(state_dict)

    EnhanceNet = EnhancementNet.EnhancementNet()

    print("DALE => Set Optimization")
    optG = torch.optim.Adam(EnhanceNet.parameters(), lr=args.lr, betas=(0.5, 0.999))
    scheduler = lr_scheduler.ExponentialLR(optG, gamma=0.99)

    model_EnhanceNet_parameters = filter(lambda p: p.requires_grad,
                                         EnhanceNet.parameters())
    params1 = sum([np.prod(p.size()) for p in model_EnhanceNet_parameters])
    print("Parameters | ", params1)

    print("DALE => Setting GPU")
    if args.cuda:
        print("DALE => Use GPU")
        VisualAttentionNet = VisualAttentionNet.cuda()
        EnhanceNet = EnhanceNet.cuda()

    print("DALE => Training")
    loss_step = 0
    for epoch in range(1, args.epochs):
        EnhanceNet.train()
        for itr, data in enumerate(loader_train):
            low_light_img, ground_truth_img, gt_Attention_img, file_name = \
                data[0], data[1], data[2], data[3]

            if args.cuda:
                low_light_img = low_light_img.cuda()
                ground_truth_img = ground_truth_img.cuda()
                gt_Attention_img = gt_Attention_img.cuda()

            optG.zero_grad()

            # the pretrained attention network only provides guidance,
            # so its output is detached from the graph
            attention_result = VisualAttentionNet(low_light_img)
            enhance_result = EnhanceNet(low_light_img, attention_result.detach())

            mse_loss = L2_loss(enhance_result, ground_truth_img)
            p_loss = Perceptual_loss(enhance_result, ground_truth_img) * 50
            tv_loss = TvLoss(enhance_result) * 20
            total_loss = p_loss + mse_loss + tv_loss

            total_loss.backward()
            optG.step()

            # decay the learning rate once per epoch after epoch 100
            if epoch > 100 and itr == 0:
                scheduler.step()
                print(scheduler.get_last_lr())

            if itr != 0 and itr % VISUALIZATION_STEP == 0:
                print("Epoch[{}/{}]({}/{}): "
                      "mse_loss : {:.6f}, "
                      "tv_loss : {:.6f}, "
                      "p_loss : {:.6f}"
                      .format(epoch, args.epochs, itr, len(loader_train),
                              mse_loss, tv_loss, p_loss))

                # VISDOM LOSS GRAPH #
                loss_dict = {
                    'mse_loss': mse_loss.item(),
                    'tv_loss': tv_loss.item(),
                    'p_loss': p_loss.item(),
                }
                visdom_loss(visdom, loss_step, loss_dict)

                # VISDOM VISUALIZATION # -> tensor to numpy => list ('title_name', img_tensor)
                with torch.no_grad():
                    val_image = Image.open('../validation/15.jpg')
                    transform = transforms.Compose([transforms.ToTensor()])
                    val_image = transform(val_image).unsqueeze(0)
                    val_image = val_image.cuda()
                    val_attention = VisualAttentionNet.eval()(val_image)
                    val_result = EnhanceNet.eval()(val_image, val_attention)

                    img_list = OrderedDict([
                        ('input', low_light_img),
                        ('output', enhance_result),
                        ('attention_output', attention_result),
                        ('gt_Attention_img', gt_Attention_img),
                        ('batch_sum', attention_result + low_light_img),
                        ('ground_truth', ground_truth_img),
                        ('val_result', val_result),
                    ])
                    visdom_image(img_dict=img_list, window=10)
                EnhanceNet.train()  # restore training mode after validation

            loss_step = loss_step + 1

        print("DALE => Testing")
        if epoch % SAVE_STEP == 0:
            train_utils.save_checkpoint(EnhanceNet, epoch, model_save_root_dir)
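# TvLoss above acts as a smoothness regularizer on the enhanced image; its
# definition is elsewhere in the repo. A common total-variation formulation,
# offered only as a sketch of what such a module typically computes (the class
# name and squared-difference choice are assumptions):
import torch
import torch.nn as nn

class TvLossSketch(nn.Module):
    """Mean squared difference between neighboring pixels (total variation)."""
    def forward(self, x):
        # x: (batch, channels, height, width)
        tv_h = torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2).mean()
        tv_w = torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2).mean()
        return tv_h + tv_w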
def main(args):
    """Main Function: Data Loading -> Model Building -> Set Optimization -> Training"""
    # Setting Important Arguments #
    args.cuda = True
    args.epochs = 200
    args.lr = 1e-5
    args.batch_size = 5

    # Setting Important Path #
    # (forward slashes avoid invalid '\d' escape sequences on Windows)
    train_data_root = 'D:/data/DALE/'
    model_save_root_dir = 'D:/Pytorch_code/DALE/checkpoint/'
    model_root = '../checkpoint/DALEGAN/'

    # Setting Important Training Variables #
    VISUALIZATION_STEP = 50
    SAVE_STEP = 1

    print("DALE => Data Loading")
    train_data = dataset_DALE.DALETrainGlobal(train_data_root, args)
    loader_train = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)

    print("DALE => Model Building")
    VAN = VisualAttentionNetwork.AttentionNet2()
    state_dict1 = torch.load(model_root + 'visual_attention_network_model.pth')
    VAN.load_state_dict(state_dict1)

    EnhanceNetG = EnhancementNet.EnhancementNet()
    EnhanceNetD = EnhancementNet.Discriminator()
    state_dict2 = torch.load(model_root + 'enhance_GAN.pth')
    EnhanceNetG.load_state_dict(state_dict2)

    EnhancementNet_parameters = filter(lambda p: p.requires_grad,
                                       EnhanceNetG.parameters())
    params1 = sum([np.prod(p.size()) for p in EnhancementNet_parameters])
    print("Parameters | Generator ", params1)

    discriminator_parameters = filter(lambda p: p.requires_grad,
                                      EnhanceNetD.parameters())
    params = sum([np.prod(p.size()) for p in discriminator_parameters])
    print("Parameters | Discriminator ", params)

    print("DALE => Set Optimization")
    optG = torch.optim.Adam(EnhanceNetG.parameters(),
                            lr=args.lr, betas=(0.5, 0.999))
    optD = torch.optim.Adam(EnhanceNetD.parameters(),
                            lr=args.lr, betas=(0.5, 0.999), weight_decay=0)

    print("DALE => Setting GPU")
    if args.cuda:
        print("DALE => Use GPU")
        VAN = VAN.cuda()
        EnhanceNetG = EnhanceNetG.cuda()
        EnhanceNetD = EnhanceNetD.cuda()

    print("DALE => Training")
    loss_step = 0
    for epoch in range(args.epochs):
        EnhanceNetG.train()
        EnhanceNetD.train()
        for itr, data in enumerate(loader_train):
            low_light_img, ground_truth_img, gt_Attention_img, file_name = \
                data[0], data[1], data[2], data[3]

            if args.cuda:
                low_light_img = low_light_img.cuda()
                ground_truth_img = ground_truth_img.cuda()
                gt_Attention_img = gt_Attention_img.cuda()

            # --- discriminator (critic) step: WGAN loss on real vs. fake ---
            optD.zero_grad()

            attention_result = VAN(low_light_img)
            enhance_result = EnhanceNetG(low_light_img, attention_result).detach()

            loss_D = -torch.mean(EnhanceNetD(ground_truth_img)) \
                     + torch.mean(EnhanceNetD(enhance_result))
            loss_D.backward()
            optD.step()

            # WGAN weight clipping keeps the critic approximately Lipschitz
            for p in EnhanceNetD.parameters():
                p.data.clamp_(-0.01, 0.01)

            # --- generator step, once every 5 critic steps ---
            if itr % 5 == 0:
                optG.zero_grad()

                enhance_result = EnhanceNetG(low_light_img, attention_result)
                # adversarial term: the critic scores the regenerated fake
                loss_G = -torch.mean(EnhanceNetD(enhance_result)) * 0.5

                e_loss = L2_loss(enhance_result, ground_truth_img)
                p_loss = Perceptual_loss(enhance_result, ground_truth_img) * 10
                tv_loss = TvLoss(enhance_result) * 5
                total_loss = p_loss + e_loss + tv_loss + loss_G

                total_loss.backward()
                optG.step()

            if itr != 0 and itr % VISUALIZATION_STEP == 0:
                print("Epoch[{}/{}]({}/{}): "
                      "e_loss : {:.6f}, "
                      "tv_loss : {:.6f}, "
                      "p_loss : {:.6f}"
                      .format(epoch, args.epochs, itr, len(loader_train),
                              e_loss, tv_loss, p_loss))

                # VISDOM LOSS GRAPH #
                loss_dict = {
                    'e_loss': e_loss.item(),
                    'tv_loss': tv_loss.item(),
                    'p_loss': p_loss.item(),
                    'g_loss': loss_G.item(),
                    'd_loss': loss_D.item()
                }
                visdom_loss(visdom, loss_step, loss_dict)

                # VISDOM VISUALIZATION # -> tensor to numpy => list ('title_name', img_tensor)
                with torch.no_grad():
                    val_image = Image.open('../validation/15.jpg')
                    transform = transforms.Compose([transforms.ToTensor()])
                    val_image = transform(val_image).unsqueeze(0)
                    val_image = val_image.cuda()
                    val_attention = VAN.eval()(val_image)
                    val_result = EnhanceNetG.eval()(val_image, val_attention)

                    img_list = OrderedDict([
                        ('input', low_light_img),
                        ('output', enhance_result),
                        ('attention_output', attention_result),
                        ('gt_Attention_img', gt_Attention_img),
                        ('ground_truth', ground_truth_img),
                        ('val_result', val_result),
                        ('val_sum', val_attention + val_image),
                    ])
                    visdom_image(img_dict=img_list, window=10)
                EnhanceNetG.train()  # restore training mode after validation

            loss_step = loss_step + 1

        print("DALE => Testing")
        if epoch % SAVE_STEP == 0:
            train_utils.save_checkpoint(EnhanceNetG, epoch,
                                        model_save_root_dir + 'DALEGAN/')
            train_utils.save_checkpoint(EnhanceNetD, epoch,
                                        model_save_root_dir + 'DALE_Discriminator/')
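# Perceptual_loss appears in all three DALE training scripts but is defined
# elsewhere. A sketch of the usual VGG-feature comparison, assuming the common
# perceptual-loss recipe and torchvision >= 0.13 (the class name, layer cutoff,
# and MSE criterion are illustrative choices, not the repo's actual code):
import torch.nn as nn
from torchvision import models

class PerceptualLossSketch(nn.Module):
    """Compares images in the feature space of a frozen VGG-19 slice."""
    def __init__(self, layer=16):
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
        self.features = vgg.features[:layer].eval()
        for p in self.features.parameters():
            p.requires_grad = False  # the loss network stays frozen
        self.criterion = nn.MSELoss()

    def forward(self, pred, target):
        return self.criterion(self.features(pred), self.features(target))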