def test(self, testloader, cur_epoch=-1):
    loss, top1, top5 = AverageTracker(), AverageTracker(), AverageTracker()

    # Set the model to be in testing mode (for dropout and batchnorm)
    self.model.eval()

    for data, target in testloader:
        if self.args.cuda:
            data, target = data.cuda(), target.cuda()
        data_var = Variable(data, volatile=True)
        target_var = Variable(target, volatile=True)

        # Forward pass
        output = self.model(data_var)
        cur_loss = self.loss(output, target_var)

        # Top-1 and Top-5 Accuracy Calculation
        cur_acc1, cur_acc5 = self.compute_accuracy(output.data, target, topk=(1, 5))
        loss.update(cur_loss.data[0])
        top1.update(cur_acc1[0])
        top5.update(cur_acc5[0])

    if cur_epoch != -1:
        # Summary Writing
        self.summary_writer.add_scalar("test-loss", loss.avg, cur_epoch)
        self.summary_writer.add_scalar("test-top-1-acc", top1.avg, cur_epoch)
        self.summary_writer.add_scalar("test-top-5-acc", top5.avg, cur_epoch)

    print("Test Results" + " | " + "loss: " + str(loss.avg) +
          " - acc-top1: " + str(top1.avg)[:7] +
          " - acc-top5: " + str(top5.avg)[:7])
def train(self):
    for cur_epoch in range(self.start_epoch, self.args.num_epochs):
        # Initialize tqdm
        tqdm_batch = tqdm(self.trainloader, desc="Epoch-" + str(cur_epoch) + "-")

        # Learning rate adjustment
        self.adjust_learning_rate(self.optimizer, cur_epoch)

        # Meters for tracking the average values
        loss, top1, top5 = AverageTracker(), AverageTracker(), AverageTracker()

        # Set the model to be in training mode (for dropout and batchnorm)
        self.model.train()

        for data, target in tqdm_batch:
            if self.args.cuda:
                data, target = data.cuda(), target.cuda()
            data_var, target_var = Variable(data), Variable(target)

            # Forward pass
            output = self.model(data_var)
            cur_loss = self.loss(output, target_var)

            # Optimization step
            self.optimizer.zero_grad()
            cur_loss.backward()
            self.optimizer.step()

            # Top-1 and Top-5 Accuracy Calculation
            cur_acc1, cur_acc5 = self.compute_accuracy(output.data, target, topk=(1, 5))
            loss.update(cur_loss.data[0])
            top1.update(cur_acc1[0])
            top5.update(cur_acc5[0])

        # Summary Writing
        self.summary_writer.add_scalar("epoch-loss", loss.avg, cur_epoch)
        self.summary_writer.add_scalar("epoch-top-1-acc", top1.avg, cur_epoch)
        self.summary_writer.add_scalar("epoch-top-5-acc", top5.avg, cur_epoch)

        # Print in console
        tqdm_batch.close()
        print("Epoch-" + str(cur_epoch) + " | " + "loss: " + str(loss.avg) +
              " - acc-top1: " + str(top1.avg)[:7] +
              " - acc-top5: " + str(top5.avg)[:7])

        # Evaluate on Validation Set
        if cur_epoch % self.args.test_every == 0 and self.valloader:
            self.test(self.valloader, cur_epoch)

        # Checkpointing
        is_best = top1.avg > self.best_top1
        self.best_top1 = max(top1.avg, self.best_top1)
        self.save_checkpoint({
            'epoch': cur_epoch + 1,
            'state_dict': self.model.state_dict(),
            'best_top1': self.best_top1,
            'optimizer': self.optimizer.state_dict(),
        }, is_best)
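# --- Hypothetical helpers (assumptions, not the repository's code) ---
# The test/train loops above call three helpers that are not shown in this
# snippet: an AverageMeter-style AverageTracker, a top-k compute_accuracy in
# the style of the PyTorch ImageNet example, and a step-decay
# adjust_learning_rate. Minimal sketches follow; `base_lr` and `decay_every`
# are made-up parameters.

class AverageTracker:
    """Tracks the latest value and the running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def compute_accuracy(output, target, topk=(1,)):
    """Top-k precision for each requested k, in percent."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def adjust_learning_rate(optimizer, epoch, base_lr=0.1, decay_every=30):
    """Divide the learning rate by 10 every `decay_every` epochs."""
    lr = base_lr * (0.1 ** (epoch // decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr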
def train():
    """
    Introduction
    ------------
    Train the RetinaNet model
    """
    train_transform = Augmentation(size=config.image_size)
    # train_dataset = COCODataset(config.coco_train_dir, config.coco_train_annaFile, config.coco_label_file, training=True, transform=train_transform)
    from VOCDataset import build_vocDataset
    train_dataset = build_vocDataset(config.voc_root)
    train_dataloader = DataLoader(train_dataset, batch_size=config.train_batch,
                                  shuffle=True, num_workers=2,
                                  collate_fn=train_dataset.collate_fn)
    print("training on {} samples".format(len(train_dataset)))
    net = RetinaNet(config.num_classes, pre_train_path=config.resnet50_path)
    net.cuda()
    optimizer = optim.SGD(net.parameters(), lr=config.learning_rate,
                          momentum=0.9, weight_decay=1e-4)
    criterion = MultiBoxLoss(alpha=config.focal_alpha, gamma=config.focal_gamma,
                             num_classes=config.num_classes)
    anchors = Anchor(config.anchor_areas, config.aspect_ratio, config.scale_ratios)
    anchor_boxes = anchors(input_size=config.image_size)
    for epoch in range(config.Epochs):
        batch_time, loc_losses, conf_losses = AverageTracker(), AverageTracker(), AverageTracker()
        net.train()
        net.freeze_bn()
        end = time.time()
        for index, (image, gt_boxes, labels) in enumerate(train_dataloader):
            loc_targets, cls_targets = [], []
            image = image.cuda()
            loc_preds, cls_preds = net(image)
            batch_num = image.shape[0]
            for idx in range(batch_num):
                # Encode targets per sample: gt_boxes/labels are indexed by the
                # in-batch sample index, not the dataloader step
                gt_box = gt_boxes[idx]
                label = labels[idx]
                loc_target, cls_target = encode(anchor_boxes, gt_box, label)
                loc_targets.append(loc_target)
                cls_targets.append(cls_target)
            loc_targets = torch.stack(loc_targets).cuda()
            cls_targets = torch.stack(cls_targets).cuda()
            loc_loss, cls_loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
            loss = loc_loss + cls_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loc_losses.update(loc_loss.item(), image.size(0))
            conf_losses.update(cls_loss.item(), image.size(0))
            batch_time.update(time.time() - end)
            end = time.time()
            if index % config.print_freq == 0:
                print('Epoch: {}/{} Batch: {}/{} loc Loss: {:.4f} {:.4f} conf loss: {:.4f} {:.4f} Time: {:.4f} {:.4f}'
                      .format(epoch, config.Epochs, index, len(train_dataloader),
                              loc_losses.val, loc_losses.avg,
                              conf_losses.val, conf_losses.avg,
                              batch_time.val, batch_time.avg))
        if epoch % config.save_freq == 0:
            print('save model')
            torch.save(net.state_dict(),
                       config.model_dir + 'train_model_epoch{}.pth'.format(epoch + 1))
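# --- Hypothetical MultiBoxLoss sketch (assumption, not this repo's code) ---
# The criterion above takes (loc_preds, loc_targets, cls_preds, cls_targets)
# and returns separate localization and classification losses. Below is a
# minimal RetinaNet-style version with that call shape, assuming cls_targets
# encodes ignored anchors as -1, background as 0, and objects as
# 1..num_classes.
import torch
import torch.nn.functional as F

def focal_detection_loss(loc_preds, loc_targets, cls_preds, cls_targets,
                         num_classes, alpha=0.25, gamma=2.0):
    pos = cls_targets > 0                       # anchors matched to an object
    num_pos = pos.sum().clamp(min=1).float()

    # localization: smooth L1 over positive anchors only
    loc_loss = F.smooth_l1_loss(loc_preds[pos], loc_targets[pos],
                                reduction='sum') / num_pos

    # classification: sigmoid focal loss over all non-ignored anchors
    valid = cls_targets > -1
    target_onehot = F.one_hot(cls_targets[valid].clamp(min=0).long(),
                              num_classes + 1)[:, 1:].float()
    p = cls_preds[valid].sigmoid()
    pt = torch.where(target_onehot == 1, p, 1 - p)
    weight = alpha * target_onehot + (1 - alpha) * (1 - target_onehot)
    cls_loss = (-weight * (1 - pt).pow(gamma)
                * pt.clamp(min=1e-8).log()).sum() / num_pos
    return loc_loss, cls_loss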
def train(self):
    all_train_iter_total_loss = []
    all_train_iter_corr_loss = []
    all_train_iter_recover_loss = []
    all_train_iter_change_loss = []
    all_train_iter_gan_loss_gen = []
    all_train_iter_gan_loss_dis = []
    all_val_epo_iou = []
    all_val_epo_acc = []
    iter_num = [0]
    epoch_num = []
    num_batches = len(self.train_dataloader)

    for epoch_i in range(self.start_epoch + 1, self.n_epoch):
        iter_total_loss = AverageTracker()
        iter_corr_loss = AverageTracker()
        iter_recover_loss = AverageTracker()
        iter_change_loss = AverageTracker()
        iter_gan_loss_gen = AverageTracker()
        iter_gan_loss_dis = AverageTracker()
        batch_time = AverageTracker()
        tic = time.time()

        # train
        self.OldLabel_generator.train()
        self.Image_generator.train()
        self.discriminator.train()
        for i, meta in enumerate(self.train_dataloader):
            image, old_label, new_label = meta[0].cuda(), meta[1].cuda(), meta[2].cuda()
            recover_pred, feats = self.OldLabel_generator(
                label2onehot(old_label, self.cfg.DATASET.N_CLASS))
            corr_pred = self.Image_generator(image, feats)

            # -------------------
            # Train Discriminator
            # -------------------
            self.discriminator.set_requires_grad(True)
            self.optimizer_D.zero_grad()
            fake_sample = torch.cat((image, corr_pred), 1).detach()
            real_sample = torch.cat(
                (image, label2onehot(new_label, self.cfg.DATASET.N_CLASS)), 1)
            score_fake_d = self.discriminator(fake_sample)
            score_real = self.discriminator(real_sample)
            gan_loss_dis = self.criterion_D(pred_score=score_fake_d,
                                            real_score=score_real)
            gan_loss_dis.backward()
            self.optimizer_D.step()
            self.scheduler_D.step()

            # ---------------
            # Train Generator
            # ---------------
            self.discriminator.set_requires_grad(False)
            self.optimizer_G.zero_grad()
            score_fake = self.discriminator(torch.cat((image, corr_pred), 1))
            total_loss, corr_loss, recover_loss, change_loss, gan_loss_gen = self.criterion_G(
                corr_pred, recover_pred, score_fake, old_label, new_label)
            total_loss.backward()
            self.optimizer_G.step()
            self.scheduler_G.step()

            iter_total_loss.update(total_loss.item())
            iter_corr_loss.update(corr_loss.item())
            iter_recover_loss.update(recover_loss.item())
            iter_change_loss.update(change_loss.item())
            iter_gan_loss_gen.update(gan_loss_gen.item())
            iter_gan_loss_dis.update(gan_loss_dis.item())
            batch_time.update(time.time() - tic)
            tic = time.time()

            log = '{}: Epoch: [{}][{}/{}], Time: {:.2f}, ' \
                  'Total Loss: {:.6f}, Corr Loss: {:.6f}, Recover Loss: {:.6f}, ' \
                  'Change Loss: {:.6f}, GAN_G Loss: {:.6f}, GAN_D Loss: {:.6f}'.format(
                      datetime.now(), epoch_i, i, num_batches, batch_time.avg,
                      total_loss.item(), corr_loss.item(), recover_loss.item(),
                      change_loss.item(), gan_loss_gen.item(), gan_loss_dis.item())
            print(log)

            if (i + 1) % 10 == 0:
                all_train_iter_total_loss.append(iter_total_loss.avg)
                all_train_iter_corr_loss.append(iter_corr_loss.avg)
                all_train_iter_recover_loss.append(iter_recover_loss.avg)
                all_train_iter_change_loss.append(iter_change_loss.avg)
                all_train_iter_gan_loss_gen.append(iter_gan_loss_gen.avg)
                all_train_iter_gan_loss_dis.append(iter_gan_loss_dis.avg)
                iter_total_loss.reset()
                iter_corr_loss.reset()
                iter_recover_loss.reset()
                iter_change_loss.reset()
                iter_gan_loss_gen.reset()
                iter_gan_loss_dis.reset()

                vis.line(
                    X=np.column_stack(np.repeat(np.expand_dims(iter_num, 0), 6, axis=0)),
                    Y=np.column_stack((all_train_iter_total_loss,
                                       all_train_iter_corr_loss,
                                       all_train_iter_recover_loss,
                                       all_train_iter_change_loss,
                                       all_train_iter_gan_loss_gen,
                                       all_train_iter_gan_loss_dis)),
                    opts={
                        'legend': ['total_loss', 'corr_loss', 'recover_loss',
                                   'change_loss', 'gan_loss_gen', 'gan_loss_dis'],
                        'linecolor': np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255],
                                               [255, 255, 0], [0, 255, 255], [255, 0, 255]]),
                        'title': 'Train loss of generator and discriminator'
                    },
                    win='Train loss of generator and discriminator')
                iter_num.append(iter_num[-1] + 1)

        # eval
        self.OldLabel_generator.eval()
        self.Image_generator.eval()
        self.discriminator.eval()
        with torch.no_grad():
            for j, meta in enumerate(self.valid_dataloader):
                image, old_label, new_label = meta[0].cuda(), meta[1].cuda(), meta[2].cuda()
                recover_pred, feats = self.OldLabel_generator(
                    label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                corr_pred = self.Image_generator(image, feats)
                preds = np.argmax(corr_pred.cpu().detach().numpy().copy(), axis=1)
                target = new_label.cpu().detach().numpy().copy()
                self.running_metrics.update(target, preds)
                if j == 0:
                    color_map1 = gen_color_map(preds[0, :]).astype(np.uint8)
                    color_map2 = gen_color_map(preds[1, :]).astype(np.uint8)
                    color_map = cv2.hconcat([color_map1, color_map2])
                    cv2.imwrite(
                        os.path.join(
                            self.val_outdir,
                            '{}epoch*{}*{}.png'.format(epoch_i, meta[3][0], meta[3][1])),
                        color_map)

        score = self.running_metrics.get_scores()
        oa = score['Overall Acc: \t']
        precision = score['Precision: \t'][1]
        recall = score['Recall: \t'][1]
        iou = score['Class IoU: \t'][1]
        miou = score['Mean IoU: \t']
        self.running_metrics.reset()

        epoch_num.append(epoch_i)
        all_val_epo_acc.append(oa)
        all_val_epo_iou.append(miou)
        vis.line(
            X=np.column_stack(np.repeat(np.expand_dims(epoch_num, 0), 2, axis=0)),
            Y=np.column_stack((all_val_epo_acc, all_val_epo_iou)),
            opts={
                'legend': ['val epoch Overall Acc', 'val epoch Mean IoU'],
                'linecolor': np.array([[255, 0, 0], [0, 255, 0]]),
                'title': 'Validate Accuracy and IoU'
            },
            win='validate Accuracy and IoU')

        log = '{}: Epoch Val: [{}], ACC: {:.2f}, Recall: {:.2f}, mIoU: {:.4f}' \
            .format(datetime.now(), epoch_i, oa, recall, miou)
        self.logger.info(log)

        state = {
            'epoch': epoch_i,
            'acc': oa,
            'recall': recall,
            'iou': iou,
            'model_G_N': self.OldLabel_generator.state_dict(),
            'model_G_I': self.Image_generator.state_dict(),
            'model_D': self.discriminator.state_dict(),
            'optimizer_G': self.optimizer_G.state_dict(),
            'optimizer_D': self.optimizer_D.state_dict()
        }
        save_path = os.path.join(self.cfg.TRAIN.OUTDIR, 'checkpoints',
                                 '{}epoch.pth'.format(epoch_i))
        torch.save(state, save_path)
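# --- Hypothetical running-metrics sketch (assumption, not this repo's code) ---
# self.running_metrics is constructed elsewhere; the validation code above
# reads scores through keys like 'Overall Acc: \t'. A minimal
# confusion-matrix implementation with that interface:
import numpy as np

class RunningMetrics:
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion = np.zeros((n_classes, n_classes), dtype=np.int64)

    def update(self, target, pred):
        # rows = ground truth, columns = prediction
        mask = (target >= 0) & (target < self.n_classes)
        self.confusion += np.bincount(
            self.n_classes * target[mask].astype(int) + pred[mask],
            minlength=self.n_classes ** 2
        ).reshape(self.n_classes, self.n_classes)

    def get_scores(self):
        cm = self.confusion.astype(np.float64)
        tp = np.diag(cm)
        iou = tp / np.maximum(cm.sum(axis=1) + cm.sum(axis=0) - tp, 1)
        return {
            'Overall Acc: \t': tp.sum() / max(cm.sum(), 1),
            'Precision: \t': tp / np.maximum(cm.sum(axis=0), 1),
            'Recall: \t': tp / np.maximum(cm.sum(axis=1), 1),
            'Class IoU: \t': iou,
            'Mean IoU: \t': iou.mean(),
        }

    def reset(self):
        self.confusion.fill(0)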
def train(self):
    all_train_iter_total_loss = []
    all_val_epo_iou = []
    all_val_epo_acc = []
    iter_num = [0]
    epoch_num = []
    num_batches = len(self.train_dataloader)

    for epoch_i in range(self.start_epoch + 1, self.n_epoch):
        iter_total_loss = AverageTracker()
        batch_time = AverageTracker()
        tic = time.time()

        # train
        self.Image_generator.train()
        for i, meta in enumerate(self.train_dataloader):
            new_image, new_label = meta[0].cuda(), meta[1].cuda()
            infer_pred = self.Image_generator(new_image)

            # ---------------
            # Train Generator
            # ---------------
            self.optimizer.zero_grad()
            total_loss = self.criterion(infer_pred, new_label)
            total_loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            iter_total_loss.update(total_loss.item())
            batch_time.update(time.time() - tic)
            tic = time.time()

            log = '{}: Epoch: [{}][{}/{}], Time: {:.2f}, Generator Total Loss: {:.6f}'.format(
                datetime.now(), epoch_i, i, num_batches, batch_time.avg, total_loss.item())
            print(log)

            if (i + 1) % 10 == 0:
                all_train_iter_total_loss.append(iter_total_loss.avg)
                iter_total_loss.reset()
                vis.line(
                    X=np.array(iter_num),
                    Y=np.array(all_train_iter_total_loss),
                    opts={'legend': ['total_loss'],
                          'linecolor': np.array([[255, 0, 0]]),
                          'title': 'Train loss of generator'},
                    win='Train loss of generator')
                iter_num.append(iter_num[-1] + 1)

        # eval
        self.Image_generator.eval()
        with torch.no_grad():
            for j, meta in enumerate(self.valid_dataloader):
                new_image, new_label = meta[0].cuda(), meta[1].cuda()
                infer_pred = self.Image_generator(new_image)
                preds = np.argmax(infer_pred.cpu().detach().numpy().copy(), axis=1)
                target = new_label.cpu().detach().numpy().copy()
                self.running_metrics.update(target, preds)
                if j == 0:
                    color_map1 = gen_color_map(preds[0, :]).astype(np.uint8)
                    color_map2 = gen_color_map(preds[1, :]).astype(np.uint8)
                    color_map = cv2.hconcat([color_map1, color_map2])
                    cv2.imwrite(
                        os.path.join(
                            self.val_outdir,
                            '{}epoch*{}*{}.png'.format(epoch_i, meta[2][0], meta[2][1])),
                        color_map)

        score = self.running_metrics.get_scores()
        oa = score['Overall Acc: \t']
        precision = score['Precision: \t'][1]
        recall = score['Recall: \t'][1]
        iou = score['Class IoU: \t'][1]
        miou = score['Mean IoU: \t']
        self.running_metrics.reset()

        epoch_num.append(epoch_i)
        all_val_epo_acc.append(oa)
        all_val_epo_iou.append(miou)
        vis.line(
            X=np.column_stack(np.repeat(np.expand_dims(epoch_num, 0), 2, axis=0)),
            Y=np.column_stack((all_val_epo_acc, all_val_epo_iou)),
            opts={
                'legend': ['val epoch Overall Acc', 'val epoch Mean IoU'],
                'linecolor': np.array([[255, 0, 0], [0, 255, 0]]),
                'title': 'Validate Accuracy and IoU'
            },
            win='validate Accuracy and IoU')

        log = '{}: Epoch Val: [{}], ACC: {:.2f}, Recall: {:.2f}, mIoU: {:.4f}' \
            .format(datetime.now(), epoch_i, oa, recall, miou)
        self.logger.info(log)

        state = {
            'epoch': epoch_i,
            'acc': oa,
            'recall': recall,
            'iou': iou,
            'model': self.Image_generator.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        save_path = os.path.join(self.cfg.TRAIN.OUTDIR, 'checkpoints',
                                 '{}epoch.pth'.format(epoch_i))
        torch.save(state, save_path)
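# --- Hypothetical gen_color_map sketch (assumption, not this repo's code) ---
# Both validation loops above colorize an (H, W) class-index map before
# cv2.imwrite. A minimal palette lookup; the colors themselves are made up.
import numpy as np

_PALETTE = np.array([[0, 0, 0],      # class 0: background
                     [0, 0, 255],    # class 1 (BGR red)
                     [0, 255, 0],    # class 2 (BGR green)
                     [255, 0, 0]],   # class 3 (BGR blue)
                    dtype=np.uint8)

def gen_color_map(pred):
    pred = np.asarray(pred, dtype=np.int64)
    return _PALETTE[np.clip(pred, 0, len(_PALETTE) - 1)]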
class DeepUNetTrainer(ModelTrainer):
    def __init__(self, *args):
        super(DeepUNetTrainer, self).__init__(*args)

        # log file
        if self.args.train:
            ctime = time.ctime().split()
            log_path = './log'
            if not os.path.exists(log_path):
                os.mkdir(log_path)
            log_dir = os.path.join(
                log_path, '%s_%s_%s_%s' % (ctime[-1], ctime[1], ctime[2], ctime[3]))
            os.mkdir(log_dir)
            with open(os.path.join(log_dir, 'arg.txt'), 'w') as f:
                f.write(str(args))
            self.log_file = open(os.path.join(log_dir, 'loss.txt'), 'w')

        self.save_path = './data/result'
        if not os.path.exists(self.save_path):
            os.mkdir(self.save_path)

        # build model
        self.generator = DeepUNetPaintGenerator().to(self.device)
        self.discriminator = PatchGAN(sigmoid=self.args.no_mse).to(self.device)

        # set optimizers
        self.optimizers = self._set_optimizers()

        # set loss functions
        self.losses = self._set_losses()

        # set image pooler
        self.image_pool = ImagePooling(50)

        # load pretrained model
        if self.args.pretrainedG != '':
            if self.args.verbose:
                print('load pretrained generator...')
            load_checkpoints(self.args.pretrainedG, self.generator,
                             self.optimizers['G'])
        if self.args.pretrainedD != '':
            if self.args.verbose:
                print('load pretrained discriminator...')
            load_checkpoints(self.args.pretrainedD, self.discriminator,
                             self.optimizers['D'])

        if self.device.type == 'cuda':
            # enable parallel computation
            self.generator = nn.DataParallel(self.generator)
            self.discriminator = nn.DataParallel(self.discriminator)

        # loss values for tracking
        self.loss_G_gan = AverageTracker('loss_G_gan')
        self.loss_G_l1 = AverageTracker('loss_G_l1')
        self.loss_D_real = AverageTracker('loss_D_real')
        self.loss_D_fake = AverageTracker('loss_D_fake')

        # image value
        self.imageA = None
        self.imageB = None
        self.fakeB = None

    def train(self, last_iteration):
        """Run single epoch"""
        average_trackers = [
            self.loss_G_gan, self.loss_D_fake, self.loss_D_real, self.loss_G_l1
        ]
        self.generator.train()
        self.discriminator.train()
        for tracker in average_trackers:
            tracker.initialize()
        for i, datas in enumerate(self.data_loader, last_iteration):
            imageA, imageB, colors = datas
            if self.args.mode == 'B2A':
                # swap
                imageA, imageB = imageB, imageA

            self.imageA = imageA.to(self.device)
            self.imageB = imageB.to(self.device)
            colors = colors.to(self.device)

            # run forward propagation; ignore attention
            self.fakeB, _ = self.generator(
                self.imageA,
                colors,
            )

            self._update_discriminator()
            self._update_generator()

            if self.args.verbose and i % self.args.print_every == 0:
                print('%s = %f, %s = %f, %s = %f, %s = %f' % (
                    self.loss_D_real.name,
                    self.loss_D_real(),
                    self.loss_D_fake.name,
                    self.loss_D_fake(),
                    self.loss_G_gan.name,
                    self.loss_G_gan(),
                    self.loss_G_l1.name,
                    self.loss_G_l1(),
                ))
                self.log_file.write('%f\t%f\t%f\t%f\n' %
                                    (self.loss_D_real(), self.loss_D_fake(),
                                     self.loss_G_gan(), self.loss_G_l1()))
        return i

    def validate(self, dataset, epoch, samples=3):
        # self.generator.eval()
        # self.discriminator.eval()
        length = len(dataset)

        # sample images
        idxs_total = [
            random.sample(range(0, length - 1), samples * 2)
            for _ in range(epoch)
        ]

        for j, idxs in enumerate(idxs_total):
            styles = idxs[samples:]
            targets = idxs[0:samples]

            result = Image.new(
                'RGB', (5 * self.resolution, samples * self.resolution))
            toPIL = transforms.ToPILImage()

            G_loss_gan = []
            G_loss_l1 = []
            D_loss_real = []
            D_loss_fake = []

            l1_loss = self.losses['L1']
            gan_loss = self.losses['GAN']
            for i, (target, style) in enumerate(zip(targets, styles)):
                sub_result = Image.new('RGB',
                                       (5 * self.resolution, self.resolution))
                imageA, imageB, _ = dataset[target]
                styleA, styleB, colors = dataset[style]
                if self.args.mode == 'B2A':
                    imageA, imageB = imageB, imageA
                    styleA, styleB = styleB, styleA

                imageA = imageA.unsqueeze(0).to(self.device)
                imageB = imageB.unsqueeze(0).to(self.device)
                styleB = styleB.unsqueeze(0).to(self.device)
                colors = colors.unsqueeze(0).to(self.device)

                with torch.no_grad():
                    fakeB, _ = self.generator(
                        imageA,
                        colors,
                    )
                    fakeAB = torch.cat([imageA, fakeB], 1)
                    realAB = torch.cat([imageA, imageB], 1)
                    G_loss_l1.append(l1_loss(fakeB, imageB).item())
                    G_loss_gan.append(
                        gan_loss(self.discriminator(fakeAB), True).item())
                    D_loss_real.append(
                        gan_loss(self.discriminator(realAB), True).item())
                    D_loss_fake.append(
                        gan_loss(self.discriminator(fakeAB), False).item())

                styleB = styleB.squeeze()
                fakeB = fakeB.squeeze()
                imageA = imageA.squeeze()
                imageB = imageB.squeeze()
                colors = colors.squeeze()

                imageA = toPIL(re_scale(imageA).detach().cpu())
                imageB = toPIL(re_scale(imageB).detach().cpu())
                styleB = toPIL(re_scale(styleB).detach().cpu())
                fakeB = toPIL(re_scale(fakeB).detach().cpu())

                # synthesize top-4 colors
                color1 = toPIL(re_scale(colors[0:3].detach().cpu()))
                color2 = toPIL(re_scale(colors[3:6].detach().cpu()))
                color3 = toPIL(re_scale(colors[6:9].detach().cpu()))
                color4 = toPIL(re_scale(colors[9:12].detach().cpu()))
                color1 = color1.rotate(90)
                color2 = color2.rotate(90)
                color3 = color3.rotate(90)
                color4 = color4.rotate(90)

                color_result = Image.new('RGB',
                                         (self.resolution, self.resolution))
                color_result.paste(
                    color1.crop((0, 0, self.resolution, self.resolution // 4)),
                    (0, 0))
                color_result.paste(
                    color2.crop((0, 0, self.resolution, self.resolution // 4)),
                    (0, self.resolution // 4))
                color_result.paste(
                    color3.crop((0, 0, self.resolution, self.resolution // 4)),
                    (0, self.resolution // 4 * 2))
                color_result.paste(
                    color4.crop((0, 0, self.resolution, self.resolution // 4)),
                    (0, self.resolution // 4 * 3))

                sub_result.paste(imageA, (0, 0))
                sub_result.paste(styleB, (self.resolution, 0))
                sub_result.paste(fakeB, (2 * self.resolution, 0))
                sub_result.paste(imageB, (3 * self.resolution, 0))
                sub_result.paste(color_result, (4 * self.resolution, 0))
                result.paste(sub_result, (0, 0 + self.resolution * i))

            print(
                'Validate D_loss_real = %f, D_loss_fake = %f, G_loss_l1 = %f, G_loss_gan = %f'
                % (
                    sum(D_loss_real) / samples,
                    sum(D_loss_fake) / samples,
                    sum(G_loss_l1) / samples,
                    sum(G_loss_gan) / samples,
                ))

            save_image(
                result,
                'deepunetpaint_%03d_%02d' % (epoch, j),
                self.save_path,
            )

    def test(self):
        raise NotImplementedError

    def save_model(self, name, epoch):
        save_checkpoints(
            self.generator,
            name + 'G',
            epoch,
            optimizer=self.optimizers['G'],
        )
        save_checkpoints(self.discriminator,
                         name + 'D',
                         epoch,
                         optimizer=self.optimizers['D'])

    def _set_optimizers(self):
        optimG = optim.Adam(self.generator.parameters(),
                            lr=self.args.learning_rate,
                            betas=(self.args.beta1, 0.999))
        optimD = optim.Adam(self.discriminator.parameters(),
                            lr=self.args.learning_rate,
                            betas=(self.args.beta1, 0.999))
        return {'G': optimG, 'D': optimD}

    def _set_losses(self):
        gan_loss = GANLoss(not self.args.no_mse).to(self.device)
        l1_loss = nn.L1Loss().to(self.device)
        return {'GAN': gan_loss, 'L1': l1_loss}

    def _update_generator(self):
        optimG = self.optimizers['G']
        gan_loss = self.losses['GAN']
        l1_loss = self.losses['L1']
        batch_size = self.imageA.shape[0]

        optimG.zero_grad()

        fake_AB = torch.cat([self.imageA, self.fakeB], 1)
        logit_fake = self.discriminator(fake_AB)
        loss_G_gan = gan_loss(logit_fake, True)
        loss_G_l1 = l1_loss(self.fakeB, self.imageB) * self.args.lambd

        self.loss_G_gan.update(loss_G_gan.item(), batch_size)
        self.loss_G_l1.update(loss_G_l1.item(), batch_size)

        loss_G = loss_G_gan + loss_G_l1
        loss_G.backward()
        optimG.step()

    def _update_discriminator(self):
        optimD = self.optimizers['D']
        gan_loss = self.losses['GAN']
        batch_size = self.imageA.shape[0]

        optimD.zero_grad()

        # for real image
        real_AB = torch.cat([self.imageA, self.imageB], 1)
        logit_real = self.discriminator(real_AB)
        loss_D_real = gan_loss(logit_real, True)
        self.loss_D_real.update(loss_D_real.item(), batch_size)

        # for fake image
        fake_AB = torch.cat([self.imageA, self.fakeB], 1)
        fake_AB = self.image_pool(fake_AB)
        logit_fake = self.discriminator(fake_AB.detach())
        loss_D_fake = gan_loss(logit_fake, False)
        self.loss_D_fake.update(loss_D_fake.item(), batch_size)

        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        optimD.step()
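# --- Hypothetical helpers for DeepUNetTrainer (assumptions, not this repo's code) ---
# The AverageTracker used by this trainer differs from the meter in the earlier
# snippets: it carries a name, is re-armed with initialize(), and is read by
# calling the instance. ImagePooling(50) is sketched as the CycleGAN-style
# history buffer that replays previously generated fakes to the discriminator.
import random
import torch

class AverageTracker:
    def __init__(self, name):
        self.name = name
        self.initialize()

    def initialize(self):
        self.sum, self.count = 0.0, 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    def __call__(self):
        return self.sum / max(self.count, 1)


class ImagePooling:
    """Keep up to pool_size past images; with probability 0.5, hand the
    discriminator a stored fake instead of the newest one."""

    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def __call__(self, images):
        if self.pool_size == 0:
            return images
        out = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                # buffer not full yet: store and pass through
                self.images.append(image)
                out.append(image)
            elif random.random() > 0.5:
                # swap a random stored image with the new one
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx].clone())
                self.images[idx] = image
            else:
                out.append(image)
        return torch.cat(out, 0)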