class DRCNTrainer(Trainer): def __init__(self, config, training_loader, testing_loader): super(DRCNTrainer, self).__init__() self.config = config self.GPU_IN_USE = torch.cuda.is_available() self.device = torch.device('cuda' if self.GPU_IN_USE else 'cpu') self.model = None self.lr = config.lr self.nEpochs = config.nEpochs self.criterion = None self.optimizer = None self.scheduler = None self.seed = config.seed self.upscale_factor = config.upscale_factor self.training_loader = training_loader self.testing_loader = testing_loader # DRCN setup self.momentum = 0.9 self.weight_decay = 0.0001 self.loss_alpha = 1.0 self.loss_alpha_zero_epoch = 25 self.loss_alpha_decay = self.loss_alpha / self.loss_alpha_zero_epoch self.loss_beta = 0.001 self.num_recursions = 16 def build_model(self): self.model = Net(num_channels=1, base_channel=64, num_recursions=self.num_recursions, device=self.device).to(self.device) self.model.weight_init() self.criterion = nn.MSELoss() torch.manual_seed(self.seed) if self.GPU_IN_USE: torch.cuda.manual_seed(self.seed) cudnn.benchmark = True self.criterion.cuda() # setup optimizer and scheduler param_groups = [{'params': list(self.model.parameters())}] param_groups += [{'params': [self.model.w]}] self.set_optimizer() def img_preprocess(self, data, interpolation='bicubic'): if interpolation == 'bicubic': interpolation = Image.BICUBIC elif interpolation == 'bilinear': interpolation = Image.BILINEAR elif interpolation == 'nearest': interpolation = Image.NEAREST size = list(data.shape) if len(size) == 4: target_height = int(size[2] * self.upscale_factor) target_width = int(size[3] * self.upscale_factor) out_data = torch.FloatTensor(size[0], size[1], target_height, target_width) for i, img in enumerate(data): transform = transforms.Compose([ transforms.ToPILImage(), transforms.Resize((target_width, target_height), interpolation=interpolation), transforms.ToTensor() ]) out_data[i, :, :, :] = transform(img) return out_data else: target_height = int(size[1] * self.upscale_factor) target_width = int(size[2] * self.upscale_factor) transform = transforms.Compose([ transforms.ToPILImage(), transforms.Resize((target_width, target_height), interpolation=interpolation), transforms.ToTensor() ]) return transform(data) def train(self): """ data: [torch.cuda.FloatTensor], 4 batches: [64, 64, 64, 8] """ self.model.train() train_loss = 0 for batch_num, (data, target) in enumerate(self.training_loader): data = self.img_preprocess(data) # resize input image size data, target = data.to(self.device), target.to(self.device) target_d, output = self.model(data) # loss1 loss_1 = 0 for d in range(self.num_recursions): loss_1 += (self.criterion(target_d[d], target) / self.num_recursions) # loss2 loss_2 = self.criterion(output, target) # regularization reg_term = 0 for theta in self.model.parameters(): reg_term += torch.mean(torch.sum(theta**2)) # total loss loss = self.loss_alpha * loss_1 + ( 1 - self.loss_alpha) * loss_2 + self.loss_beta * reg_term loss.backward() train_loss += loss.item() self.optimizer.step() total_time = progress_bar( batch_num, len(self.training_loader), 'Loss: %.4f' % (train_loss / (batch_num + 1))) avg_loss = train_loss / len(self.training_loader) return [avg_loss, total_time] def test(self): """ data: [torch.cuda.FloatTensor], 10 batches: [10, 10, 10, 10, 10, 10, 10, 10, 10, 10] """ self.model.eval() avg_psnr = 0 with torch.no_grad(): for batch_num, (data, target) in enumerate(self.testing_loader): data = self.img_preprocess(data) # resize input image size data, target = data.to(self.device), target.to(self.device) _, prediction = self.model(data) mse = self.criterion(prediction, target) psnr = 10 * log10(1 / mse.item()) avg_psnr += psnr total_time = progress_bar( batch_num, len(self.testing_loader), 'PSNR: %.4f' % (avg_psnr / (batch_num + 1))) avg_psnr = avg_psnr / len(self.testing_loader) return [avg_psnr, total_time] def run(self): self.build_model() for epoch in range(1, self.nEpochs + 1): print("\n===> Epoch {} starts:".format(epoch)) self.loss_alpha = max(0.0, self.loss_alpha - self.loss_alpha_decay) avg_loss = self.train() avg_psnr = self.test() self.scheduler.step(epoch) self.save_model(epoch=epoch, avg_error=avg_loss, avg_psnr=avg_psnr)
class DRCNTrainer(object): def __init__(self, config, training_loader, testing_loader): self.model = None self.lr = config.lr self.nEpochs = config.nEpochs self.criterion = None self.optimizer = None self.scheduler = None self.GPU_IN_USE = torch.cuda.is_available() self.seed = config.seed self.upscale_factor = config.upscale_factor self.training_loader = training_loader self.testing_loader = testing_loader # DRCN setup self.momentum = 0.9 self.weight_decay = 0.0001 self.loss_alpha = 1.0 self.loss_alpha_zero_epoch = 25 self.loss_alpha_decay = self.loss_alpha / self.loss_alpha_zero_epoch self.loss_beta = 0.001 self.num_recursions = 16 def build_model(self): self.model = Net(num_channels=1, base_channel=64, num_recursions=self.num_recursions) self.model.weight_init() self.criterion = nn.MSELoss() torch.manual_seed(self.seed) if self.GPU_IN_USE: torch.cuda.manual_seed(self.seed) cudnn.benchmark = True self.model.cuda() self.criterion.cuda() # setup optimizer and scheduler param_groups = [{'params': list(self.model.parameters())}] param_groups += [{'params': [self.model.w]}] self.optimizer = optim.Adam(param_groups, lr=self.lr) self.scheduler = optim.lr_scheduler.MultiStepLR( self.optimizer, milestones=[50, 75, 100], gamma=0.5) # lr decay def img_preprocess(self, data, interpolation='bicubic'): if interpolation == 'bicubic': interpolation = Image.BICUBIC elif interpolation == 'bilinear': interpolation = Image.BILINEAR elif interpolation == 'nearest': interpolation = Image.NEAREST size = list(data.shape) if len(size) == 4: target_height = int(size[2] * self.upscale_factor) target_width = int(size[3] * self.upscale_factor) out_data = torch.FloatTensor(size[0], size[1], target_height, target_width) for i, img in enumerate(data): transform = transforms.Compose([ transforms.ToPILImage(), transforms.Resize((target_width, target_height), interpolation=interpolation), transforms.ToTensor() ]) out_data[i, :, :, :] = transform(img) return out_data else: target_height = int(size[1] * self.upscale_factor) target_width = int(size[2] * self.upscale_factor) transform = transforms.Compose([ transforms.ToPILImage(), transforms.Resize((target_width, target_height), interpolation=interpolation), transforms.ToTensor() ]) return transform(data) def save(self): model_out_path = "DRCN_model_path.pth" torch.save(self.model, model_out_path) print("Checkpoint saved to {}".format(model_out_path)) def train(self): """ data: [torch.cuda.FloatTensor], 4 batches: [64, 64, 64, 8] """ self.model.train() train_loss = 0 for batch_num, (data, target) in enumerate(self.training_loader): data = self.img_preprocess(data) # resize input image size if self.GPU_IN_USE: data, target = Variable(data).cuda(), Variable(target).cuda() self.optimizer.zero_grad() target_d, output = self.model(data) # loss1 loss_1 = 0 for d in range(self.num_recursions): loss_1 += (self.criterion(target_d[d], target) / self.num_recursions) # loss2 loss_2 = self.criterion(output, target) # regularization reg_term = 0 for theta in self.model.parameters(): reg_term += torch.mean(torch.sum(theta**2)) # total loss loss = self.loss_alpha * loss_1 + ( 1 - self.loss_alpha) * loss_2 + self.loss_beta * reg_term loss.backward() train_loss += loss.data[0] self.optimizer.step() progress_bar(batch_num, len(self.training_loader), 'Loss: %.4f' % (train_loss / (batch_num + 1))) print(" Average Loss: {:.4f}".format(train_loss / len(self.training_loader))) def test(self): """ data: [torch.cuda.FloatTensor], 10 batches: [10, 10, 10, 10, 10, 10, 10, 10, 10, 10] """ self.model.eval() avg_psnr = 0 for batch_num, (data, target) in enumerate(self.testing_loader): data = self.img_preprocess(data) # resize input image size if self.GPU_IN_USE: data, target = Variable(data).cuda(), Variable(target).cuda() _, prediction = self.model(data) mse = self.criterion(prediction, target) psnr = 10 * log10(1 / mse.data[0]) avg_psnr += psnr progress_bar(batch_num, len(self.testing_loader), 'PSNR: %.4f' % (avg_psnr / (batch_num + 1))) print(" Average PSNR: {:.4f} dB".format(avg_psnr / len(self.testing_loader))) def validate(self): self.build_model() for epoch in range(1, self.nEpochs + 1): print("\n===> Epoch {} starts:".format(epoch)) self.loss_alpha = max(0.0, self.loss_alpha - self.loss_alpha_decay) self.train() self.test() self.scheduler.step(epoch) if epoch == self.nEpochs: self.save()