class Network(nn.Module):
    """Adversarial candidate-selection network.

    A (pretrained, frozen-at-use) ResNet-50 extracts features for an anchor
    image and for real/fake candidates; a Discriminator scores each
    (anchor, candidate) feature pair.  The Selector is trained with a
    REINFORCE-style signal derived from the discriminator's fake scores.
    """

    def __init__(self):
        super(Network, self).__init__()
        self.feature_extractor = resnet50()  # Already pretrained
        # self.feature_extractor = resnet50(pretrained_path=None)
        self.selector = Selector()
        self.dis = Discriminator()
        self.optmzr_select = Adam(self.selector.parameters(), lr=1e-3)
        self.optmzr_dis = Adam(self.dis.parameters(), lr=1e-3)

    def forward(self, anchor: Variable, real_data: Variable, fake_data: Variable):
        """Score real and fake candidates against the anchor.

        All three inputs are 4-D image batches (N, C, H, W).
        Returns (score_real, score_fake) as produced by the discriminator.
        """
        # BUG FIX: the original asserted `anchor` twice and never validated
        # real_data / fake_data.
        assert len(anchor.size()) == 4 and len(real_data.size()) == 4 \
            and len(fake_data.size()) == 4
        fea_anchor = self.feature_extractor(anchor)
        fea_real = self.feature_extractor(real_data)
        fea_fake = self.feature_extractor(fake_data)
        # The feature extractor is not trained here, so cut its gradients.
        fea_anchor = fea_anchor.detach()
        fea_real = fea_real.detach()
        fea_fake = fea_fake.detach()
        score_real = self.dis(fea_anchor, fea_real)
        score_fake = self.dis(fea_anchor, fea_fake)
        return score_real, score_fake

    def bp_dis(self, score_real, score_fake):
        """One discriminator update with noisy (label-smoothed) 1/0 targets."""
        real_label = Variable(torch.normal(torch.ones(score_real.size()),
                                           torch.zeros(score_real.size()) + 0.05)).cuda()
        fake_label = Variable(torch.normal(torch.zeros(score_real.size()),
                                           torch.zeros(score_real.size()) + 0.05)).cuda()
        # size_average=False keeps parity with the file's (older) torch API.
        loss = torch.mean(F.binary_cross_entropy(score_real, real_label, size_average=False) +
                          F.binary_cross_entropy(score_fake, fake_label, size_average=False))
        # loss = -(torch.mean(torch.log(score_real + 1e-6)) - torch.mean(torch.log(.5 + score_fake / 2 + 1e-6)))
        self.optmzr_dis.zero_grad()
        loss.backward()
        return self.optmzr_dis.step()

    def bp_select(self, score_fake: Variable, fake_prob):
        """Accumulate selector gradients: d/dθ log(prob) weighted by the
        (rescaled) discriminator score of the selected fakes.

        NOTE(review): no optimizer step is taken here — presumably the caller
        steps `optmzr_select` after accumulating; confirm.
        """
        n_sample = score_fake.size()[0]
        # BUG FIX: the original zeroed optmzr_dis, but this step trains the
        # selector — its own optimizer must be zeroed.
        self.optmzr_select.zero_grad()
        # Map scores from [0, 1] to [-1, 1] so poor fakes push prob down.
        re = (score_fake.data - .5) * 2
        torch.log(fake_prob).backward(re / n_sample)
class CoCosModel(BaseModel):
    """Exemplar-guided image translation (CoCosNet-style).

    netC (CorrespondenceNet) aligns an exemplar to the input and produces a
    warped exemplar; netT (TranslationNet) renders the output image from the
    warped result.  netD is a (possibly multiscale) discriminator used only
    at train time.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        # No model-specific options; hook kept for the framework interface.
        return parser

    @staticmethod
    def torch2numpy(x):
        # Map a CHW tensor in [-1, 1] to an HWC uint8 image in [0, 255].
        return ((x.detach().cpu().numpy().transpose(1, 2, 0) + 1) * 127.5).astype(np.uint8)

    def __name__(self):
        return 'CoCosModel'

    def __init__(self, opt):
        super().__init__(opt)
        self.w = opt.image_size
        # make a folder for saved images
        self.image_dir = os.path.join(self.save_dir, 'images')
        if not os.path.isdir(self.image_dir):
            os.mkdir(self.image_dir)
        # initialize networks
        self.model_names = ['C', 'T']
        self.netC = CorrespondenceNet(opt)
        self.netT = TranslationNet(opt)
        if opt.isTrain:
            self.model_names.append('D')
            self.netD = Discriminator(opt)
        self.visual_names = ['b_exemplar', 'a', 'b_gen', 'b_gt']  # HPT convention
        if opt.isTrain:
            # assign losses
            self.loss_names = ['perc', 'domain', 'feat', 'context', 'reg', 'adv']
            self.visual_names += ['b_warp']
            self.criterionFeat = torch.nn.L1Loss()
            # Single interface for both VGG perceptual and contextual losses;
            # called with different mode/layer params below.
            self.criterionVGG = VGGLoss(self.device)
            # Supports hinge loss via opt.gan_mode
            self.criterionAdv = GANLoss(gan_mode=opt.gan_mode).to(self.device)
            self.criterionDomain = nn.L1Loss()
            self.criterionReg = torch.nn.L1Loss()
            # initialize optimizers
            gen_params = itertools.chain(self.netT.parameters(), self.netC.parameters())
            self.optG = torch.optim.Adam(gen_params, lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optD = torch.optim.Adam(self.netD.parameters(), lr=opt.lr,
                                         betas=(opt.beta1, 0.999))
            self.optimizers = [self.optG, self.optD]
        # Finally, load checkpoints and recover schedulers
        self.setup(opt)
        torch.autograd.set_detect_anomaly(True)

    def set_input(self, batch):
        # expecting 'a' -> 'b_gt', 'a_exemplar' -> 'b_exemplar', ('b_deform')
        # for human pose transfer, 'b_deform' is already 'b_exemplar'
        for k, v in batch.items():
            setattr(self, k, v.to(self.device))

    def forward(self):
        # Correspondence returns domain features sa/sb plus the warped
        # exemplar as feature (fb_warp) and RGB (b_warp); attention is 3*HW*HW.
        self.sa, self.sb, self.fb_warp, self.b_warp = self.netC(self.a, self.b_exemplar)
        self.b_gen = self.netT(self.b_warp)
        # self.b_gen = self.netT(self.fb_warp)  # retain original feature or use warped rgb?
        # TODO: Implement backward warping (maybe we should adjust the input size?)
        _, _, _, self.b_reg = self.netC(
            self.a_exemplar,
            F.interpolate(self.b_warp, (self.w, self.w), mode='bilinear'))
        # print(self.b_gen.shape, self.b_reg.shape, self.b_gt.shape)

    def test(self):
        """Inference-only forward: warp the exemplar and render, no grads."""
        with torch.no_grad():
            _, _, _, self.b_warp = self.netC(self.a, self.b_exemplar)  # 3*HW*HW
            self.b_gen = self.netT(self.b_warp)

    def backward_G(self):
        """Generator/correspondence update: sum of six weighted losses."""
        self.optG.zero_grad()
        # 1. Perc loss (for human pose transfer it is folded into criterionFeat)
        self.loss_perc = 0
        # 2. domain loss
        self.loss_domain = self.opt.lambda_domain * self.criterionDomain(self.sa, self.sb)
        # 3. losses for pseudo exemplar pairs
        self.loss_feat = self.opt.lambda_feat * self.criterionVGG(
            self.b_gen, self.b_gt, mode='perceptual')
        # 4. Contextual loss
        self.loss_context = self.opt.lambda_context * self.criterionVGG(
            self.b_gen, self.b_exemplar, mode='contextual', layers=[2, 3, 4, 5])
        # 5. Reg loss (warp back the exemplar and compare at netC resolution)
        b_exemplar_small = F.interpolate(self.b_exemplar, self.b_reg.size()[2:],
                                         mode='bilinear')
        self.loss_reg = self.opt.lambda_reg * self.criterionReg(self.b_reg, b_exemplar_small)
        # 6. GAN loss
        # BUG FIX: discriminate() returns (pred_fake, pred_real); the original
        # unpacked them in the opposite order, so the generator's adversarial
        # loss was computed on the REAL-image predictions.
        pred_fake, pred_real = self.discriminate(self.b_gt, self.b_gen)
        self.loss_adv = self.opt.lambda_adv * self.criterionAdv(
            pred_fake, True, for_discriminator=False)
        g_loss = (self.loss_perc + self.loss_domain + self.loss_feat
                  + self.loss_context + self.loss_reg + self.loss_adv)
        g_loss.backward()
        self.optG.step()

    def discriminate(self, real, fake):
        """Run D on fake and real stacked in one batch; returns (pred_fake, pred_real)."""
        fake_and_real = torch.cat([fake, real], dim=0)
        discriminator_out = self.netD(fake_and_real)
        pred_fake, pred_real = self.divide_pred(discriminator_out)
        return pred_fake, pred_real

    # Take the prediction of fake and real images from the combined batch
    def divide_pred(self, pred):
        # the prediction contains the intermediate outputs of a multiscale GAN,
        # so it's usually a list
        if isinstance(pred, list):
            fake = [p[:p.size(0) // 2] for p in pred]
            real = [p[p.size(0) // 2:] for p in pred]
        else:
            fake = pred[:pred.size(0) // 2]
            real = pred[pred.size(0) // 2:]
        return fake, real

    def backward_D(self):
        self.optD.zero_grad()
        # Regenerate b_gen under no_grad so D sees a detached fake.
        self.test()
        pred_fake, pred_real = self.discriminate(self.b_gt, self.b_gen)
        self.d_fake = self.criterionAdv(pred_fake, False, for_discriminator=True)
        self.d_real = self.criterionAdv(pred_real, True, for_discriminator=True)
        d_loss = (self.d_fake + self.d_real) / 2
        d_loss.backward()
        self.optD.step()

    def optimize_parameters(self):
        # must call self.set_input(data) first
        self.forward()
        self.backward_G()
        self.backward_D()

    ### Standalone utility functions
    def log_loss(self, epoch, iter):
        """Print all tracked losses for the current iteration."""
        msg = 'Epoch %d iter %d\n ' % (epoch, iter)
        for name in self.loss_names:
            val = getattr(self, 'loss_%s' % name)
            # BUG FIX: the original only unwrapped torch.cuda.FloatTensor,
            # which misses CPU and non-float tensors; any tensor needs .item().
            if isinstance(val, torch.Tensor):
                val = val.item()
            msg += '%s: %.4f, ' % (name, val)
        print(msg)

    def log_visual(self, epoch, iter):
        """Save a side-by-side visualization of the current visuals."""
        # BUG FIX: the original used self.save_image_dir, which is never
        # defined — __init__ creates self.image_dir.
        save_path = os.path.join(self.image_dir,
                                 'epoch%03d_iter%05d.png' % (epoch, iter))
        # warped image is not at the same resolution, needs scaling
        self.b_warp = F.interpolate(self.b_warp, (self.w, self.w), mode='bicubic')
        pack = torch.cat([getattr(self, name) for name in self.visual_names],
                         dim=3)[0]  # only save one example
        cv2.imwrite(save_path, self.torch2numpy(pack))
        # BUG FIX: 'b_ex' + save_path prefixed the whole path (invalid file
        # location); prefix only the file name inside image_dir.
        exemplar_path = os.path.join(
            self.image_dir, 'b_ex_epoch%03d_iter%05d.png' % (epoch, iter))
        cv2.imwrite(exemplar_path, self.torch2numpy(self.b_exemplar[0]))

    def update_learning_rate(self):
        '''Update learning rates for all the networks; called at the end of
        every epoch by train.py'''
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate updated to %.7f' % lr)
class Trainer(object):
    """Domain-adaptive semantic segmentation trainer.

    Trains a DeepLab segmenter with supervised loss on the source domain and
    output-space adversarial alignment (entropy maps fed to a discriminator)
    against an unlabeled target domain.
    """

    def __init__(self, config, args):
        self.args = args
        self.config = config
        self.visdom = args.visdom
        if args.visdom:
            self.vis = visdom.Visdom(env=os.getcwd().split('/')[-1], port=8888)
        # Define Dataloaders (source and target domains)
        self.train_loader, self.val_loader, self.test_loader, self.nclass = \
            make_data_loader(config)
        self.target_train_loader, self.target_val_loader, self.target_test_loader, _ = \
            make_target_data_loader(config)
        # Define network
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)
        self.D = Discriminator(num_classes=self.nclass, ndf=16)
        # Backbone at base lr, head at lr * lr_ratio
        train_params = [{'params': self.model.get_1x_lr_params(), 'lr': config.lr},
                        {'params': self.model.get_10x_lr_params(),
                         'lr': config.lr * config.lr_ratio}]
        # Define Optimizers
        self.optimizer = torch.optim.SGD(train_params,
                                         momentum=config.momentum,
                                         weight_decay=config.weight_decay)
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), lr=config.lr,
                                            betas=(0.9, 0.99))
        # Define Criterion (whether to use class balanced weights)
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.entropy_mini_loss = MinimizeEntropyLoss()
        self.bottleneck_loss = BottleneckLoss()
        self.instance_loss = InstanceLoss()
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler, config.lr,
                                      config.epochs, len(self.train_loader),
                                      config.lr_step, config.warmup_epochs)
        self.summary = TensorboardSummary('./train_log')
        # labels for adversarial training
        self.source_label = 0
        self.target_label = 1
        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            # cudnn.benchmark = True
            self.model = self.model.cuda()
            self.D = torch.nn.DataParallel(self.D)
            patch_replication_callback(self.D)
            self.D = self.D.cuda()
        self.best_pred_source = 0.0
        self.best_pred_target = 0.0
        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            # BUG FIX: map_location is an argument of torch.load, not of
            # load_state_dict — the original CPU branch raised a TypeError.
            if args.cuda:
                checkpoint = torch.load(args.resume)
                self.model.module.load_state_dict(checkpoint)
            else:
                checkpoint = torch.load(args.resume,
                                        map_location=torch.device('cpu'))
                self.model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))

    def training(self, epoch):
        """One epoch of joint segmenter + discriminator training."""
        train_loss, seg_loss_sum, bn_loss_sum, entropy_loss_sum, \
            adv_loss_sum, d_loss_sum, ins_loss_sum = \
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
        self.model.train()
        # BUG FIX: the original read the global name `config`; use the
        # config stored on the instance.
        if self.config.freeze_bn:
            self.model.module.freeze_bn()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        target_train_iterator = iter(self.target_train_loader)
        for i, sample in enumerate(tbar):
            itr = epoch * len(self.train_loader) + i
            self.summary.writer.add_scalar(
                'Train/lr', self.optimizer.param_groups[0]['lr'], itr)
            A_image, A_target = sample['image'], sample['label']
            # Get one batch from target domain; restart its iterator when exhausted
            try:
                target_sample = next(target_train_iterator)
            except StopIteration:
                target_train_iterator = iter(self.target_train_loader)
                target_sample = next(target_train_iterator)
            B_image, B_target, B_image_pair = (target_sample['image'],
                                               target_sample['label'],
                                               target_sample['image_pair'])
            if self.args.cuda:
                A_image, A_target = A_image.cuda(), A_target.cuda()
                B_image, B_target, B_image_pair = \
                    B_image.cuda(), B_target.cuda(), B_image_pair.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target, self.config.lr_ratio)
            self.scheduler(self.D_optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target, self.config.lr_ratio)
            A_output, A_feat, A_low_feat = self.model(A_image)
            B_output, B_feat, B_low_feat = self.model(B_image)
            # B_output_pair, ... = self.model(B_image_pair)  (instance loss, disabled)
            self.optimizer.zero_grad()
            self.D_optimizer.zero_grad()
            # --- Train seg network: freeze D ---
            for param in self.D.parameters():
                param.requires_grad = False
            # Supervised loss on the source domain
            seg_loss = self.criterion(A_output, A_target)
            main_loss = seg_loss
            # ins_loss = 0.01 * self.instance_loss(B_output, B_output_pair)  (disabled)
            # Adversarial loss: push D to label target entropy maps as "source".
            # BUG FIX: F.softmax without dim is deprecated/ambiguous; class
            # scores live on dim 1 for NCHW segmentation logits.
            D_out = self.D(prob_2_entropy(F.softmax(B_output, dim=1)))
            adv_loss = bce_loss(D_out, self.source_label)
            main_loss += self.config.lambda_adv * adv_loss
            main_loss.backward()
            # --- Train discriminator ---
            for param in self.D.parameters():
                param.requires_grad = True
            A_output_detach = A_output.detach()
            B_output_detach = B_output.detach()
            # source
            D_source = self.D(prob_2_entropy(F.softmax(A_output_detach, dim=1)))
            source_loss = bce_loss(D_source, self.source_label) / 2
            # target
            D_target = self.D(prob_2_entropy(F.softmax(B_output_detach, dim=1)))
            target_loss = bce_loss(D_target, self.target_label) / 2
            d_loss = source_loss + target_loss
            d_loss.backward()
            self.optimizer.step()
            self.D_optimizer.step()
            seg_loss_sum += seg_loss.item()
            adv_loss_sum += self.config.lambda_adv * adv_loss.item()
            d_loss_sum += d_loss.item()
            train_loss += seg_loss.item()
            self.summary.writer.add_scalar('Train/SegLoss', seg_loss.item(), itr)
            self.summary.writer.add_scalar('Train/AdvLoss', adv_loss.item(), itr)
            self.summary.writer.add_scalar('Train/DiscriminatorLoss',
                                           d_loss.item(), itr)
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
        # Show the results of the last iteration
        # if i == len(self.train_loader)-1:
        print("Add Train images at epoch" + str(epoch))
        self.summary.visualize_image('Train-Source', self.config.dataset,
                                     A_image, A_target, A_output, epoch, 5)
        self.summary.visualize_image('Train-Target', self.config.target,
                                     B_image, B_target, B_output, epoch, 5)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.config.batch_size + A_image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

    def validation(self, epoch):
        """Validate on source and target loaders; checkpoint on improvement."""

        def get_metrics(tbar, if_source=False):
            # Evaluate one loader; returns (Acc, IoU, mIoU).
            self.evaluator.reset()
            test_loss = 0.0
            for i, sample in enumerate(tbar):
                image, target = sample['image'], sample['label']
                if self.args.cuda:
                    image, target = image.cuda(), target.cuda()
                with torch.no_grad():
                    # NOTE(review): unpack order (output, low_feat, feat)
                    # differs from training()'s (output, feat, low_feat);
                    # only `output` is used here, but confirm the model's
                    # actual return order.
                    output, low_feat, feat = self.model(image)
                loss = self.criterion(output, target)
                test_loss += loss.item()
                tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
                pred = output.data.cpu().numpy()
                target_ = target.cpu().numpy()
                pred = np.argmax(pred, axis=1)
                # Add batch sample into evaluator
                self.evaluator.add_batch(target_, pred)
            if if_source:
                print("Add Validation-Source images at epoch" + str(epoch))
                self.summary.visualize_image('Val-Source', self.config.dataset,
                                             image, target, output, epoch, 5)
            else:
                print("Add Validation-Target images at epoch" + str(epoch))
                self.summary.visualize_image('Val-Target', self.config.target,
                                             image, target, output, epoch, 5)
            # Fast test during the training
            Acc = self.evaluator.Building_Acc()
            IoU = self.evaluator.Building_IoU()
            mIoU = self.evaluator.Mean_Intersection_over_Union()
            if if_source:
                print('Validation on source:')
            else:
                print('Validation on target:')
            print('[Epoch: %d, numImages: %5d]' %
                  (epoch, i * self.config.batch_size + image.data.shape[0]))
            print("Acc:{}, IoU:{}, mIoU:{}".format(Acc, IoU, mIoU))
            print('Loss: %.3f' % test_loss)
            # (removed dead `names = [...]` locals that were never read)
            if if_source:
                self.summary.writer.add_scalar('Val/SourceAcc', Acc, epoch)
                self.summary.writer.add_scalar('Val/SourceIoU', IoU, epoch)
            else:
                self.summary.writer.add_scalar('Val/TargetAcc', Acc, epoch)
                self.summary.writer.add_scalar('Val/TargetIoU', IoU, epoch)
            return Acc, IoU, mIoU

        self.model.eval()
        tbar_source = tqdm(self.val_loader, desc='\r')
        tbar_target = tqdm(self.target_val_loader, desc='\r')
        s_acc, s_iou, s_miou = get_metrics(tbar_source, True)
        t_acc, t_iou, t_miou = get_metrics(tbar_target, False)
        new_pred_source = s_iou
        new_pred_target = t_iou
        if new_pred_source > self.best_pred_source or \
                new_pred_target > self.best_pred_target:
            is_best = True
            self.best_pred_source = max(new_pred_source, self.best_pred_source)
            self.best_pred_target = max(new_pred_target, self.best_pred_target)
            print('Saving state, epoch:', epoch)
            torch.save(self.model.module.state_dict(),
                       self.args.save_folder + 'models/' +
                       'epoch' + str(epoch) + '.pth')
        loss_file = {'s_Acc': s_acc, 's_IoU': s_iou, 's_mIoU': s_miou,
                     't_Acc': t_acc, 't_IoU': t_iou, 't_mIoU': t_miou}
        with open(os.path.join(self.args.save_folder, 'eval',
                               'epoch' + str(epoch) + '.json'), 'w') as f:
            json.dump(loss_file, f)
def main():
    """Entry point for GTA->Cityscapes ('G2C') adversarial domain adaptation:
    builds the dataloaders, the DeepLab generator, two discriminators and
    their optimizers, optionally restores a checkpoint, then trains."""
    # dataset preparation
    source_data, target_data, val_data = create_dataset(mode='G2C')
    # NOTE(review): `parser` is used like a parsed-arguments namespace, not an
    # argparse.ArgumentParser — confirm at its definition site.
    source_dataloader = Data.DataLoader(source_data,
                                        batch_size=parser.batch_size,
                                        shuffle=True,
                                        num_workers=parser.num_workers,
                                        pin_memory=True)
    target_dataloader = Data.DataLoader(target_data,
                                        batch_size=parser.batch_size,
                                        shuffle=True,
                                        num_workers=parser.num_workers,
                                        pin_memory=True)
    # No shuffle for deterministic validation order.
    val_dataloader = Data.DataLoader(val_data,
                                     batch_size=parser.batch_size,
                                     shuffle=False,
                                     num_workers=parser.num_workers,
                                     pin_memory=True)
    source_dataloader_iter = enumerate(source_dataloader)
    target_dataloader_iter = enumerate(target_dataloader)
    save_dir = parser.ckpt_dir
    # create model and optimizers (SGD for the generator, Adam for both D's)
    model = create_model(num_classes=parser.num_classes, name='DeepLab')
    D1 = Discriminator(num_classes=parser.num_classes)
    D2 = Discriminator(num_classes=parser.num_classes)
    optimizer_G = create_optimizer(model.get_optim_params(parser),
                                   lr=parser.learning_rate,
                                   momentum=parser.momentum,
                                   weight_decay=parser.weight_decay,
                                   name="SGD")
    optimizer_D1 = create_optimizer(D1.parameters(), lr=LEARNING_RATE_D,
                                    name="Adam", betas=BETAS)
    optimizer_D2 = create_optimizer(D2.parameters(), lr=LEARNING_RATE_D,
                                    name="Adam", betas=BETAS)
    optimizer_G.zero_grad()
    optimizer_D1.zero_grad()
    optimizer_D2.zero_grad()
    start_iter = 1
    last_mIoU = 0
    if parser.restore:
        print("loading checkpoint...")
        # NOTE(review): loads from `save_dir` (ckpt_dir) — presumably a file
        # path despite the name; verify.
        checkpoint = torch.load(save_dir)
        start_iter = checkpoint['iter']
        model.load_state_dict(checkpoint['model'])
        optimizer_G.load_state_dict(checkpoint['optimizer']['G'])
        optimizer_D1.load_state_dict(checkpoint['optimizer']['D1'])
        optimizer_D2.load_state_dict(checkpoint['optimizer']['D2'])
        last_mIoU = checkpoint['best_mIoU']
    print("start training...")
    print("pytorch version: " + TORCH_VERSION + ", cuda version: " +
          TORCH_CUDA_VERSION + ", cudnn version: " + CUDNN_VERSION)
    print("available graphical device: " + DEVICE_NAME)
    os.system("nvidia-smi")
    discriminator = {'D1': D1, 'D2': D2}
    optimizer = {'G': optimizer_G, 'D1': optimizer_D1, 'D2': optimizer_D2}
    best_mIoU, best_iter = train(model, discriminator, optimizer,
                                 source_dataloader_iter,
                                 target_dataloader_iter, val_dataloader,
                                 start_iter, last_mIoU)
    print("finished training, the best mIoU is: " + str(best_mIoU) +
          " in iteration " + str(best_iter))
# Script fragment: adversarial feature-training setup.  NOTE(review): the
# train_dis() definition below is truncated at this chunk's boundary.
if is_cuda:
    feature_extractor.cuda()
    dis.cuda()

# input pipeline
data_iter = DataProvider(batch_size, is_cuda=is_cuda)

# summary writer (optional — only when a log path is configured)
if log_path:
    writer = SummaryWriter(log_path, 'comment test')
else:
    writer = None

# optimizers (Adam with default hyper-parameters)
opt_d = Adam(dis.parameters())
opt_fea = Adam(feature_extractor.parameters())


def train_dis():
    """One discriminator step on (anchor, real, wrong) triplets.

    NOTE(review): body appears truncated here — the wrong-image branch and
    the loss/step logic are not visible in this chunk.
    """
    # label
    # real_label = Variable(torch.normal(torch.ones(batch_size), torch.zeros(batch_size) + 0.02)).cuda()
    # fake_label = Variable(torch.normal(torch.zeros(batch_size), torch.zeros(batch_size) + 0.02)).cuda()
    real_label = Variable(torch.ones(batch_size).cuda())
    fake_label = Variable(torch.zeros(batch_size).cuda())
    anchor, real_img, wrong_img = data_iter.next()
    anchor, real_img, wrong_img = Variable(anchor), Variable(
        real_img), Variable(wrong_img)
    fea_anc = feature_extractor(anchor)
    fea_real = feature_extractor(real_img)
class Model(pl.LightningModule):
    """CycleGAN as a LightningModule.

    Two generators (G_AB, G_BA) and two discriminators (D_A, D_B) trained
    with adversarial (MSE), cycle-consistency (L1) and identity (L1) losses
    via three optimizers dispatched on `optimizer_index`.
    """

    def __init__(self, hparams, device, G_AB, G_BA):
        super(Model, self).__init__()
        self.hparams = hparams
        self.device = device
        self.input_shape = hparams.input_shape
        self.learning_rate = hparams.learning_rate
        self.B1 = hparams.b1
        # BUG FIX: the original read hparams.b1 for BOTH Adam betas,
        # silently ignoring hparams.b2.
        self.B2 = hparams.b2
        self.n_epochs = hparams.n_epochs
        self.start_epoch = hparams.start_epoch
        self.epoch_decay = hparams.epoch_decay
        self.batch_size = hparams.batch_size
        self.lambda_cycle_loss = hparams.lambda_cycle_loss
        self.lambda_identity_loss = hparams.lambda_identity_loss
        self.G_AB = G_AB
        self.G_BA = G_BA
        self.D_A = Discriminator(self.input_shape)
        self.D_B = Discriminator(self.input_shape)
        # Adversarial ground truths (fixed all-ones / all-zeros patch targets)
        self.valid = torch.ones(
            (self.batch_size, *self.D_A.output_shape)).to(device)
        self.fake = torch.zeros(
            (self.batch_size, *self.D_A.output_shape)).to(device)
        # Losses
        self.criterion_GAN = torch.nn.MSELoss()
        self.criterion_cycle = torch.nn.L1Loss()
        self.criterion_identity = torch.nn.L1Loss()
        # Cached between the generator step and the discriminator steps.
        self.fake_A = None
        self.fake_B = None
        self.recov_A = None
        self.recov_B = None

    def forward(self, real_A, real_B):
        """Translate both directions: returns (G_AB(real_A), G_BA(real_B))."""
        return self.G_AB(real_A), self.G_BA(real_B)

    def training_step(self, batch, batch_index, optimizer_index=0):
        loss = None
        loss_type = None
        # Set model input
        real_A = batch["A"].to(self.device)
        real_B = batch["B"].to(self.device)
        # -------------------------------
        #  Train Generators (G_AB, G_BA)
        # -------------------------------
        if optimizer_index == 0:
            # Identity loss
            loss_id_A = self.criterion_identity(self.G_BA(real_A), real_A)
            loss_id_B = self.criterion_identity(self.G_AB(real_B), real_B)
            loss_identity = (loss_id_A + loss_id_B) / 2
            # GAN loss
            self.fake_B, self.fake_A = self.forward(real_A, real_B)
            loss_GAN_AB = self.criterion_GAN(self.D_B(self.fake_B), self.valid)
            loss_GAN_BA = self.criterion_GAN(self.D_A(self.fake_A), self.valid)
            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
            # forward(fake_A, fake_B) = (G_AB(fake_A), G_BA(fake_B))
            #                         = (recov_B, recov_A)
            self.recov_B, self.recov_A = self.forward(self.fake_A, self.fake_B)
            # Cycle loss
            loss_cycle_A = self.criterion_cycle(self.recov_A, real_A)
            loss_cycle_B = self.criterion_cycle(self.recov_B, real_B)
            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
            # Total loss
            loss = (loss_GAN + self.lambda_cycle_loss * loss_cycle
                    + self.lambda_identity_loss * loss_identity)
            loss_type = 'G'
        # -----------------------
        #  Train Discriminator A
        # -----------------------
        elif optimizer_index == 1:
            # Real loss
            loss_real = self.criterion_GAN(self.D_A(real_A), self.valid)
            # Fake loss (detached: do not backprop into the generator)
            loss_fake = self.criterion_GAN(self.D_A(self.fake_A.detach()),
                                           self.fake)
            # Total loss
            loss = (loss_real + loss_fake) / 2
            loss_type = "D_A"
        # -----------------------
        #  Train Discriminator B
        # -----------------------
        elif optimizer_index == 2:
            # Real loss
            loss_real = self.criterion_GAN(self.D_B(real_B), self.valid)
            # Fake loss
            loss_fake = self.criterion_GAN(self.D_B(self.fake_B.detach()),
                                           self.fake)
            # Total loss
            loss = (loss_real + loss_fake) / 2
            loss_type = "D_B"
        tqdm_dict = {f"{loss_type}_loss": loss}
        return OrderedDict({
            'loss': loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })

    def validation_step(self, batch, batch_nb):
        if batch_nb == 0:
            self.sample_network_images(batch)
        # Reuse the generator branch of training_step for the val metric.
        loss_data = self.training_step(batch, batch_nb)
        return {
            'val_loss': loss_data['loss'],
            'progress_bar': loss_data['progress_bar'],
            'log': loss_data['log']
        }

    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        # Optimizers: one for both generators, one per discriminator.
        optimizer_G = torch.optim.Adam(itertools.chain(self.G_AB.parameters(),
                                                       self.G_BA.parameters()),
                                       lr=self.learning_rate,
                                       betas=(self.B1, self.B2))
        optimizer_D_A = torch.optim.Adam(self.D_A.parameters(),
                                         lr=self.learning_rate,
                                         betas=(self.B1, self.B2))
        optimizer_D_B = torch.optim.Adam(self.D_B.parameters(),
                                         lr=self.learning_rate,
                                         betas=(self.B1, self.B2))
        # Learning rate update schedulers (linear decay after epoch_decay)
        lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            optimizer_G,
            lr_lambda=LambdaLRSteper(self.n_epochs, self.start_epoch,
                                     self.epoch_decay).step)
        lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
            optimizer_D_A,
            lr_lambda=LambdaLRSteper(self.n_epochs, self.start_epoch,
                                     self.epoch_decay).step)
        lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
            optimizer_D_B,
            lr_lambda=LambdaLRSteper(self.n_epochs, self.start_epoch,
                                     self.epoch_decay).step)
        return [optimizer_G, optimizer_D_A, optimizer_D_B], \
               [lr_scheduler_G, lr_scheduler_D_A, lr_scheduler_D_B]

    @pl.data_loader
    def train_dataloader(self):
        return self.create_data_loader(MODE_TRAIN)

    @pl.data_loader
    def val_dataloader(self):
        return self.create_data_loader(MODE_VAL)

    @pl.data_loader
    def test_dataloader(self):
        return self.create_data_loader(MODE_TEST)

    def create_data_loader(self, mode):
        return DataLoader(
            ImageDataset(mode),
            batch_size=self.batch_size,
            shuffle=True,
            # num_workers=multiprocessing.cpu_count(),
        )

    def sample_network_images(self, batch):
        """Saves a generated sample from the test set"""
        real_A = batch["A"].to(self.device)
        real_B = batch["B"].to(self.device)
        fake_B, fake_A = self.forward(real_A, real_B)
        # Arrange images along x-axis
        real_A = make_grid(real_A, nrow=5, normalize=True)
        real_B = make_grid(real_B, nrow=5, normalize=True)
        fake_A = make_grid(fake_A, nrow=5, normalize=True)
        fake_B = make_grid(fake_B, nrow=5, normalize=True)
        # Arrange images along y-axis
        image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
        self.logger.experiment.add_image(f'sample_images_{self.current_epoch}',
                                         image_grid, 0)
# Progressive-growing level (re)construction: while `transition_coef` ramps
# up to 1 the nets are built in transition (blending) mode; once it reaches
# 1 the transition ends and the nets are rebuilt without blending.
if transition_coef <= 1. and transition:
    generator = Generator(config['sr'][level_index],
                          config,
                          transition=transition,
                          transition_coef=transition_coef).to(device)
    critic = Discriminator(config['sr'][level_index],
                           config,
                           transition=transition,
                           transition_coef=transition_coef).to(device)
    generator_optimizer = torch.optim.Adam(
        generator.parameters(),
        lr=lr_gen,
        betas=config['generator_betas'])
    critic_optimizer = torch.optim.Adam(critic.parameters(),
                                        lr=lr_dis,
                                        betas=config['critic_betas'])
elif transition_coef >= 1. and transition:
    transition = False
    # NOTE(review): this branch rebuilds the networks but no optimizers are
    # created here — if they are not re-created after this fragment, the
    # existing optimizers still reference the OLD modules' parameters; confirm.
    generator = Generator(config['sr'][level_index],
                          config,
                          transition=transition,
                          transition_coef=transition_coef).to(device)
    critic = Discriminator(config['sr'][level_index],
                           config,
                           transition=transition,
                           transition_coef=transition_coef).to(device)
class Hidden:
    """HiDDeN watermarking model: owns the encoder/decoder, the adversarial
    discriminator, their optimizers and losses, and implements one training
    and one validation step per batch."""

    def __init__(self, configuration: HiDDenConfiguration,
                 device: torch.device, noiser: Noiser, tb_logger):
        """
        :param configuration: Configuration for the net, such as the size of the input image, number of channels in the intermediate layers, etc.
        :param device: torch.device object, CPU or GPU
        :param noiser: Object representing stacked noise layers.
        :param tb_logger: Optional TensorboardX logger object, if specified -- enables Tensorboard logging
        """
        super(Hidden, self).__init__()
        self.encoder_decoder = EncoderDecoder(configuration, noiser).to(device)
        self.optimizer_enc_dec = torch.optim.Adam(
            self.encoder_decoder.parameters())
        self.discriminator = Discriminator(configuration).to(device)
        self.optimizer_discrim = torch.optim.Adam(
            self.discriminator.parameters())
        if configuration.use_vgg:
            self.vgg_loss = VGGLoss(3, 1, False)
            self.vgg_loss.to(device)
        else:
            self.vgg_loss = None

        self.config = configuration
        self.device = device
        self.bce_with_logits_loss = nn.BCEWithLogitsLoss()
        self.mse_loss = nn.MSELoss()

        # Defined the labels used for training the discriminator/adversarial loss
        self.cover_label = 1
        self.encoded_label = 0

        self.tb_logger = tb_logger
        if tb_logger is not None:
            # FIX: removed the unused `from tensorboard_logger import
            # TensorBoardLogger` import — only the injected logger is used.
            # Register gradient hooks on the final layers so gradient
            # magnitudes show up in TensorBoard.
            encoder_final = self.encoder_decoder.encoder._modules[
                'final_layer']
            encoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/encoder_out'))
            decoder_final = self.encoder_decoder.decoder._modules['linear']
            decoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/decoder_out'))
            discrim_final = self.discriminator._modules['linear']
            discrim_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/discrim_out'))

    def train_on_batch(self, batch: list):
        """ Trains the network on a single batch consisting of images and messages
        :param batch: batch of training data, in the form [images, messages]
        :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch """
        images, messages = batch
        batch_size = images.shape[0]
        with torch.enable_grad():
            # ---------------- Train the discriminator -----------------------------
            self.optimizer_discrim.zero_grad()
            # train on cover
            # FIX: explicit float dtype — BCEWithLogitsLoss requires float
            # targets, and torch.full with an int fill value produces a
            # LongTensor on recent torch versions.
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              dtype=torch.float,
                                              device=self.device)
            d_on_cover = self.discriminator(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)
            d_loss_on_cover.backward()

            # train on fake
            encoded_images, noised_images, decoded_messages = self.encoder_decoder(
                images, messages)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encoded_label,
                                                dtype=torch.float,
                                                device=self.device)
            # detach() so the discriminator step sends no gradients into the
            # encoder/decoder.
            d_on_encoded = self.discriminator(encoded_images.detach())
            d_loss_on_encoded = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)
            d_loss_on_encoded.backward()
            self.optimizer_discrim.step()

            # --------------Train the generator (encoder-decoder) ---------------------
            self.optimizer_enc_dec.zero_grad()
            # target label for encoded images should be 'cover', because we want to fool the discriminator
            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                dtype=torch.float,
                                                device=self.device)
            d_on_encoded_for_enc = self.discriminator(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)
            # FIX: `is None` instead of `== None`.
            if self.vgg_loss is None:
                g_loss_enc = self.mse_loss(encoded_images, images)
            else:
                # Perceptual distance in VGG feature space.
                vgg_on_cov = self.vgg_loss(images)
                vgg_on_enc = self.vgg_loss(encoded_images)
                g_loss_enc = self.mse_loss(vgg_on_cov, vgg_on_enc)

            g_loss_dec = self.mse_loss(decoded_messages, messages)
            g_loss = self.config.adversarial_loss * g_loss_adv + self.config.encoder_loss * g_loss_enc \
                + self.config.decoder_loss * g_loss_dec
            g_loss.backward()
            self.optimizer_enc_dec.step()

        # Bit error rate: round decoded messages to {0,1} and compare.
        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_avg_err = np.sum(
            np.abs(decoded_rounded - messages.detach().cpu().numpy())) / (
                batch_size * messages.shape[1])

        losses = {
            'loss ': g_loss.item(),
            'encoder_mse ': g_loss_enc.item(),
            'dec_mse ': g_loss_dec.item(),
            'bitwise-error ': bitwise_avg_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encoded.item()
        }
        return losses, (encoded_images, noised_images, decoded_messages)

    def validate_on_batch(self, batch: list):
        """ Runs validation on a single batch of data consisting of images and messages
        :param batch: batch of validation data, in form [images, messages]
        :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch """
        # if TensorboardX logging is enabled, save some of the tensors.
        if self.tb_logger is not None:
            encoder_final = self.encoder_decoder.encoder._modules[
                'final_layer']
            self.tb_logger.add_tensor('weights/encoder_out',
                                      encoder_final.weight)
            decoder_final = self.encoder_decoder.decoder._modules['linear']
            self.tb_logger.add_tensor('weights/decoder_out',
                                      decoder_final.weight)
            discrim_final = self.discriminator._modules['linear']
            self.tb_logger.add_tensor('weights/discrim_out',
                                      discrim_final.weight)

        images, messages = batch
        batch_size = images.shape[0]
        with torch.no_grad():
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              dtype=torch.float,
                                              device=self.device)
            # FIX: the discriminator was run on `images` twice; one forward
            # pass suffices.
            d_on_cover = self.discriminator(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)

            encoded_images, noised_images, decoded_messages = self.encoder_decoder(
                images, messages)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encoded_label,
                                                dtype=torch.float,
                                                device=self.device)
            d_on_encoded = self.discriminator(encoded_images)
            d_loss_on_encoded = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                dtype=torch.float,
                                                device=self.device)
            d_on_encoded_for_enc = self.discriminator(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)
            if self.vgg_loss is None:
                g_loss_enc = self.mse_loss(encoded_images, images)
            else:
                vgg_on_cov = self.vgg_loss(images)
                vgg_on_enc = self.vgg_loss(encoded_images)
                g_loss_enc = self.mse_loss(vgg_on_cov, vgg_on_enc)

            g_loss_dec = self.mse_loss(decoded_messages, messages)
            g_loss = self.config.adversarial_loss * g_loss_adv + self.config.encoder_loss * g_loss_enc \
                + self.config.decoder_loss * g_loss_dec

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_avg_err = np.sum(
            np.abs(decoded_rounded - messages.detach().cpu().numpy())) / (
                batch_size * messages.shape[1])

        losses = {
            'loss ': g_loss.item(),
            'encoder_mse ': g_loss_enc.item(),
            'dec_mse ': g_loss_dec.item(),
            'bitwise-error ': bitwise_avg_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encoded.item()
        }
        return losses, (encoded_images, noised_images, decoded_messages)

    def to_stirng(self):
        # NOTE: misspelled name kept for backward compatibility with existing
        # callers; prefer `to_string`.
        return '{}\n{}'.format(str(self.encoder_decoder),
                               str(self.discriminator))

    def to_string(self):
        """Correctly-spelled alias of `to_stirng`."""
        return self.to_stirng()
def main(data_dir):
    """Train the PRNet cycle-GAN (g_x: image->UV map, g_y: UV map->image,
    plus discriminators d_x/d_y) on the 300W-LP dataset under `data_dir`,
    periodically saving sample images and checkpoints into FLAGS['images'].
    """
    origin_img, uv_map_gt, uv_map_predicted = None, None, None

    if not os.path.exists(FLAGS['images']):
        os.mkdir(FLAGS['images'])

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToResize((416, 416)),
                              ToNormalize(FLAGS["normalize_mean"],
                                          FLAGS["normalize_std"])
                          ]))
    wlp300_dataloader = DataLoader(dataset=wlp300,
                                   batch_size=FLAGS['batch_size'],
                                   shuffle=True,
                                   num_workers=1)

    # 2) Intermediate Processing (used on the fixed test image below).
    transform_img = transforms.Compose([
        #transforms.ToTensor(),
        transforms.Normalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
    ])

    # 3) Create PRNet model.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    g_x = ResFCN256(resolution_input=416,
                    resolution_output=416,
                    channel=3,
                    size=16)
    g_y = ResFCN256(resolution_input=416,
                    resolution_output=416,
                    channel=3,
                    size=16)
    d_x = Discriminator()
    d_y = Discriminator()

    # Load the pre-trained weight
    if FLAGS['resume'] != "" and os.path.exists(
            os.path.join(FLAGS['pretrained'], FLAGS['resume'])):
        state = torch.load(os.path.join(FLAGS['pretrained'], FLAGS['resume']))
        try:
            g_x.load_state_dict(state['g_x'])
            g_y.load_state_dict(state['g_y'])
            d_x.load_state_dict(state['d_x'])
            d_y.load_state_dict(state['d_y'])
        except Exception:
            # Fall back to a plain (non-GAN) PRNet checkpoint.
            g_x.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO(
            "Pre-trained weight cannot load successfully, train from scratch!")

    # FIX: the original wrapped an undefined name `model` in DataParallel
    # (NameError); wrap the four actual networks instead.
    if torch.cuda.device_count() > 1:
        g_x = torch.nn.DataParallel(g_x)
        g_y = torch.nn.DataParallel(g_y)
        d_x = torch.nn.DataParallel(d_x)
        d_y = torch.nn.DataParallel(d_y)
    g_x.to(FLAGS["device"])
    g_y.to(FLAGS["device"])
    d_x.to(FLAGS["device"])
    d_y.to(FLAGS["device"])

    optimizer_g = torch.optim.Adam(itertools.chain(g_x.parameters(),
                                                   g_y.parameters()),
                                   lr=FLAGS["lr"],
                                   betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(itertools.chain(d_x.parameters(),
                                                   d_y.parameters()),
                                   lr=FLAGS["lr"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer_g,
                                                       gamma=0.99)

    stat_loss = SSIM(mask_path=FLAGS["mask_path"],
                     gauss=FLAGS["gauss_kernel"])
    loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])
    bce_loss = torch.nn.BCEWithLogitsLoss()
    bce_loss.to(FLAGS["device"])
    l1_loss = nn.L1Loss().to(FLAGS["device"])

    # Cycle-consistency weights.
    lambda_X = 10
    lambda_Y = 10

    #Loss function for adversarial
    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        loss_list_cycle_x = []
        loss_list_cycle_y = []
        loss_list_d_x = []
        loss_list_d_y = []
        for i, sample in enumerate(bar):
            real_y, real_x = sample['uv_map'].to(
                FLAGS['device']), sample['origin'].to(FLAGS['device'])

            # x -> y' -> x^
            optimizer_g.zero_grad()
            fake_y = g_x(real_x)
            prediction = d_x(fake_y)
            # FIX: labels are built with *_like so they live on the same
            # device as the prediction and match the (possibly short) last
            # batch, instead of CPU tensors of size FLAGS['batch_size'].
            loss_g_x = bce_loss(prediction, torch.ones_like(prediction))
            x_hat = g_y(fake_y)
            loss_cycle_x = l1_loss(x_hat, real_x) * lambda_X
            loss_x = loss_g_x + loss_cycle_x
            loss_x.backward(retain_graph=True)
            optimizer_g.step()
            loss_list_cycle_x.append(loss_x.item())

            # y -> x' -> y^
            optimizer_g.zero_grad()
            fake_x = g_y(real_y)
            prediction = d_y(fake_x)
            loss_g_y = bce_loss(prediction, torch.ones_like(prediction))
            y_hat = g_x(fake_x)
            loss_cycle_y = l1_loss(y_hat, real_y) * lambda_Y
            loss_y = loss_g_y + loss_cycle_y
            loss_y.backward(retain_graph=True)
            optimizer_g.step()
            loss_list_cycle_y.append(loss_y.item())

            # d_x
            optimizer_d.zero_grad()
            pred_real = d_x(real_y)
            loss_d_x_real = bce_loss(pred_real, torch.ones_like(pred_real))
            pred_fake = d_x(fake_y)
            loss_d_x_fake = bce_loss(pred_fake, torch.zeros_like(pred_fake))
            loss_d_x = (loss_d_x_real + loss_d_x_fake) * 0.5
            loss_d_x.backward()
            loss_list_d_x.append(loss_d_x.item())
            optimizer_d.step()
            if 'WGAN' in FLAGS['gan_type']:
                # Weight clipping for WGAN-style training.
                for p in d_x.parameters():
                    p.data.clamp_(-1, 1)

            # d_y
            optimizer_d.zero_grad()
            pred_real = d_y(real_x)
            loss_d_y_real = bce_loss(pred_real, torch.ones_like(pred_real))
            pred_fake = d_y(fake_x)
            loss_d_y_fake = bce_loss(pred_fake, torch.zeros_like(pred_fake))
            loss_d_y = (loss_d_y_real + loss_d_y_fake) * 0.5
            loss_d_y.backward()
            loss_list_d_y.append(loss_d_y.item())
            optimizer_d.step()
            if 'WGAN' in FLAGS['gan_type']:
                for p in d_y.parameters():
                    p.data.clamp_(-1, 1)

        if ep % FLAGS["save_interval"] == 0:
            with torch.no_grad():
                # FIX: the original printed loss_list_g_x / loss_list_g_y,
                # which were never defined (NameError).
                print(
                    " {} [Loss_G_X] {} [Loss_G_Y] {} [Loss_D_X] {} [Loss_D_Y] {}"
                    .format(ep, loss_list_cycle_x[-1], loss_list_cycle_y[-1],
                            loss_list_d_x[-1], loss_list_d_y[-1]))
                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = np.load("./test_data/test_obama.npy")
                origin, gt_uv_map = test_data_preprocess(
                    origin), test_data_preprocess(gt_uv_map)
                origin, gt_uv_map = transform_img(origin), transform_img(
                    gt_uv_map)
                origin_in = origin.unsqueeze_(0).cuda()
                pred_uv_map = g_x(origin_in).detach().cpu()
                save_image(
                    [origin.cpu(),
                     gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                    os.path.join(FLAGS['images'],
                                 str(ep) + '.png'),
                    nrow=1,
                    normalize=True)

            # Save model
            print("Save model")
            state = {
                'g_x': g_x.state_dict(),
                'g_y': g_y.state_dict(),
                'd_x': d_x.state_dict(),
                'd_y': d_y.state_dict(),
                'start_epoch': ep,
            }
            torch.save(state,
                       os.path.join(FLAGS['images'], '{}.pth'.format(ep)))
        scheduler.step()
def train(train_sources, eval_source):
    """Adversarial domain-adaptation training: a UNet segmentator plus a
    vendor discriminator. Schedule (by epoch): <50 segmentator only,
    >=50 discriminator, >=100 adversarial step with growing weight `lmbd`.

    NOTE(review): `epochs = 3` below means the >=50/>=100 phases never run as
    committed — presumably a debug value; confirm before relying on it.
    """
    # Dataset root comes from the CLI, not from the parameters.
    path = sys.argv[1]
    dr = DataReader(path, train_sources)
    dr.read()
    print(len(dr.train.x))

    batch_size = 8
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')

    # Annotated source-domain splits (train split is augmented).
    dataset_s_train = MultiDomainDataset(dr.train.x, dr.train.y,
                                         dr.train.vendor, device,
                                         DomainAugmentation())
    dataset_s_dev = MultiDomainDataset(dr.dev.x, dr.dev.y, dr.dev.vendor,
                                       device)
    dataset_s_test = MultiDomainDataset(dr.test.x, dr.test.y, dr.test.vendor,
                                        device)
    loader_s_train = DataLoader(dataset_s_train, batch_size, shuffle=True)

    # Unannotated evaluation domain.
    dr_eval = DataReader(path, [eval_source])
    dr_eval.read()
    dataset_eval_dev = MultiDomainDataset(dr_eval.dev.x, dr_eval.dev.y,
                                          dr_eval.dev.vendor, device)
    dataset_eval_test = MultiDomainDataset(dr_eval.test.x, dr_eval.test.y,
                                           dr_eval.test.vendor, device)

    # Mixed-domain loader feeding the discriminator / adversarial steps.
    dataset_da_train = MultiDomainDataset(dr.train.x + dr_eval.train.x,
                                          dr.train.y + dr_eval.train.y,
                                          dr.train.vendor + dr_eval.train.vendor,
                                          device, DomainAugmentation())
    loader_da_train = DataLoader(dataset_da_train, batch_size, shuffle=True)

    segmentator = UNet()
    discriminator = Discriminator(n_domains=len(train_sources))
    discriminator.to(device)
    segmentator.to(device)
    sigmoid = nn.Sigmoid()
    selector = Selector()

    s_criterion = nn.BCELoss()
    d_criterion = nn.CrossEntropyLoss()
    s_optimizer = optim.AdamW(segmentator.parameters(), lr=0.0001,
                              weight_decay=0.01)
    d_optimizer = optim.AdamW(discriminator.parameters(), lr=0.001,
                              weight_decay=0.01)
    # Adversarial optimizer only touches the encoder.
    a_optimizer = optim.AdamW(segmentator.encoder.parameters(), lr=0.001,
                              weight_decay=0.01)
    lmbd = 1 / 150

    s_train_losses = []
    s_dev_losses = []
    d_train_losses = []
    eval_domain_losses = []
    train_dices = []
    dev_dices = []
    eval_dices = []
    epochs = 3
    da_loader_iter = iter(loader_da_train)
    for epoch in tqdm(range(epochs)):
        s_train_loss = 0.0
        d_train_loss = 0.0
        for index, sample in enumerate(loader_s_train):
            img = sample['image']
            target_mask = sample['target']
            da_sample = next(da_loader_iter, None)
            if epoch == 100:
                # FIX: assigning optimizer.defaults['lr'] does NOT change the
                # learning rate of already-created parameter groups; the per
                # group 'lr' entry is what the optimizer actually reads.
                for group in s_optimizer.param_groups:
                    group['lr'] = 0.001
                for group in d_optimizer.param_groups:
                    group['lr'] = 0.0001
            if da_sample is None:
                # Mixed-domain loader exhausted: restart it.
                da_loader_iter = iter(loader_da_train)
                da_sample = next(da_loader_iter, None)

            if epoch < 50 or epoch >= 100:
                # Training step of segmentator
                predicted_activations, inner_repr = segmentator(img)
                predicted_mask = sigmoid(predicted_activations)
                s_loss = s_criterion(predicted_mask, target_mask)
                s_optimizer.zero_grad()
                s_loss.backward()
                s_optimizer.step()
                s_train_loss += s_loss.cpu().detach().numpy()

            if epoch >= 50:
                # Training step of discriminator (segmentator outputs are
                # detached so only the discriminator learns here).
                predicted_activations, inner_repr = segmentator(
                    da_sample['image'])
                predicted_activations = predicted_activations.clone().detach()
                inner_repr = inner_repr.clone().detach()
                predicted_vendor = discriminator(predicted_activations,
                                                 inner_repr)
                d_loss = d_criterion(predicted_vendor, da_sample['vendor'])
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()
                d_train_loss += d_loss.cpu().detach().numpy()

            if epoch >= 100:
                # adversarial training step: push the encoder to confuse the
                # vendor discriminator (negated loss, weight ramps up).
                predicted_mask, inner_repr = segmentator(da_sample['image'])
                predicted_vendor = discriminator(predicted_mask, inner_repr)
                a_loss = -1 * lmbd * d_criterion(predicted_vendor,
                                                 da_sample['vendor'])
                a_optimizer.zero_grad()
                a_loss.backward()
                a_optimizer.step()
                lmbd += 1 / 150

        # Per-epoch bookkeeping: evaluate with selector+sigmoid head.
        inference_model = nn.Sequential(segmentator, selector, sigmoid)
        inference_model.to(device)
        inference_model.eval()
        d_train_losses.append(d_train_loss / len(loader_s_train))
        s_train_losses.append(s_train_loss / len(loader_s_train))
        s_dev_losses.append(
            calculate_loss(dataset_s_dev, inference_model, s_criterion,
                           batch_size))
        eval_domain_losses.append(
            calculate_loss(dataset_eval_dev, inference_model, s_criterion,
                           batch_size))
        train_dices.append(calculate_dice(inference_model, dataset_s_train))
        dev_dices.append(calculate_dice(inference_model, dataset_s_dev))
        eval_dices.append(calculate_dice(inference_model, dataset_eval_dev))
        segmentator.train()

    date_time = datetime.now().strftime("%m%d%Y_%H%M%S")
    model_path = os.path.join(
        pathlib.Path(__file__).parent.absolute(), "model", "weights",
        "segmentator" + str(date_time) + ".pth")
    torch.save(segmentator.state_dict(), model_path)
    util.plot_data([(s_train_losses, 'train_losses'),
                    (s_dev_losses, 'dev_losses'),
                    (d_train_losses, 'discriminator_losses'),
                    (eval_domain_losses, 'eval_domain_losses')], 'losses.png')
    util.plot_dice([(train_dices, 'train_dice'), (dev_dices, 'dev_dice'),
                    (eval_dices, 'eval_dice')], 'dices.png')

    inference_model = nn.Sequential(segmentator, selector, sigmoid)
    inference_model.to(device)
    inference_model.eval()
    print('Dice on annotated: ',
          calculate_dice(inference_model, dataset_s_test))
    print('Dice on unannotated: ',
          calculate_dice(inference_model, dataset_eval_test))
def main():
    """Main function that trains and/or evaluates a model.

    Pipeline: (optionally) MLE pre-train the generator, pre-train the CNN
    discriminator, then run adversarial (SeqGAN-style) training, and finally
    evaluate on the validation split.
    """
    params = interpret_args()
    if params.gan:
        # GAN sampling requires the generation cap to match both SQL caps.
        # NOTE(review): in the collapsed original the extent of this `if`
        # block is ambiguous; only the assert is treated as guarded here.
        assert params.max_gen_len == params.train_maximum_sql_length \
            == params.eval_maximum_sql_length
    data = atis_data.ATISDataset(params)

    # ---------------------- Generator ----------------------
    generator = SchemaInteractionATISModel(params, data.input_vocabulary,
                                           data.output_vocabulary,
                                           data.output_vocabulary_schema,
                                           None)
    generator = generator.cuda()
    generator.build_optim()
    if params.gen_from_ckp:
        # Resume generator (and its optimizers) from a checkpoint.
        gen_ckp_path = os.path.join(params.logdir, params.gen_pretrain_ckp)
        if params.fine_tune_bert:
            gen_epoch, generator, generator.trainer, \
                generator.bert_trainer = \
                load_ckp(
                    gen_ckp_path,
                    generator,
                    generator.trainer,
                    generator.bert_trainer
                )
        else:
            gen_epoch, generator, generator.trainer, _ = \
                load_ckp(
                    gen_ckp_path,
                    generator,
                    generator.trainer
                )
    else:
        gen_epoch = 0

    # Diagnostics: confirm every parameter is on GPU and list optimizer state.
    print('====================Model Parameters====================')
    print('=======================Generator========================')
    for name, param in generator.named_parameters():
        print(name, param.requires_grad, param.is_cuda, param.size())
        assert param.is_cuda
    print('==================Optimizer Parameters==================')
    print('=======================Generator========================')
    for param_group in generator.trainer.param_groups:
        print(param_group.keys())
        for param in param_group['params']:
            print(param.size())
    if params.fine_tune_bert:
        print('=========================BERT===========================')
        for param_group in generator.bert_trainer.param_groups:
            print(param_group.keys())
            for param in param_group['params']:
                print(param.size())
    sys.stdout.flush()

    # Pre-train generator with MLE
    if params.train:
        print('=============== Pre-training generator! ================')
        train(generator, data, params, gen_epoch)
        print('=========== Pre-training generator complete! ===========')

    # ---------------------- Discriminator ----------------------
    # CNN text discriminator: one filter bank per width 1, 5, 9, ...
    dis_filter_sizes = [i for i in range(1, params.max_gen_len, 4)]
    dis_num_filters = [(100 + i * 10)
                       for i in range(1, params.max_gen_len, 4)]
    discriminator = Discriminator(params, data.dis_src_vocab,
                                  data.dis_tgt_vocab, params.max_gen_len,
                                  params.num_dis_classes, dis_filter_sizes,
                                  dis_num_filters, params.max_pos_emb,
                                  params.num_tok_type, params.dis_dropout)
    discriminator = discriminator.cuda()
    dis_criterion = nn.NLLLoss(reduction='mean')
    dis_criterion = dis_criterion.cuda()
    dis_optimizer = optim.Adam(discriminator.parameters())
    if params.dis_from_ckp:
        dis_ckp_path = os.path.join(params.logdir, params.dis_pretrain_ckp)
        dis_epoch, discriminator, dis_optimizer, _ = load_ckp(
            dis_ckp_path, discriminator, dis_optimizer)
    else:
        dis_epoch = 0

    print('====================Model Parameters====================')
    print('=====================Discriminator======================')
    for name, param in discriminator.named_parameters():
        print(name, param.requires_grad, param.is_cuda, param.size())
        assert param.is_cuda
    print('==================Optimizer Parameters==================')
    print('=====================Discriminator======================')
    for param_group in dis_optimizer.param_groups:
        print(param_group.keys())
        for param in param_group['params']:
            print(param.size())
    sys.stdout.flush()

    # Pre-train discriminator
    if params.pretrain_discriminator:
        print('============= Pre-training discriminator! ==============')
        pretrain_discriminator(params, generator, discriminator,
                               dis_criterion, dis_optimizer, data,
                               start_epoch=dis_epoch)
        print('========= Pre-training discriminator complete! =========')

    # Adversarial Training
    if params.adversarial_training:
        print('================ Adversarial training! =================')
        # Fresh optimizers/criterion for the adversarial phase.
        generator.build_optim()
        dis_criterion = nn.NLLLoss(reduction='mean')
        dis_optimizer = optim.Adam(discriminator.parameters())
        dis_criterion = dis_criterion.cuda()
        # FIX: original used `params.mle is not "mixed_mle"` — identity
        # comparison against a string literal; use != / ==.
        if params.adv_from_ckp and params.mle != "mixed_mle":
            adv_ckp_path = os.path.join(params.logdir, params.adv_ckp)
            if params.fine_tune_bert:
                epoch, batches, pos_in_batch, generator, discriminator, \
                    generator.trainer, dis_optimizer, \
                    generator.bert_trainer, _, _ = \
                    load_adv_ckp(
                        adv_ckp_path, generator, discriminator,
                        generator.trainer, dis_optimizer,
                        generator.bert_trainer)
            else:
                epoch, batches, pos_in_batch, generator, discriminator, \
                    generator.trainer, dis_optimizer, _, _, _ = \
                    load_adv_ckp(
                        adv_ckp_path, generator, discriminator,
                        generator.trainer, dis_optimizer)
            adv_train(generator, discriminator, dis_criterion, dis_optimizer,
                      data, params, start_epoch=epoch, start_batches=batches,
                      start_pos_in_batch=pos_in_batch)
        elif params.adv_from_ckp and params.mle == "mixed_mle":
            adv_ckp_path = os.path.join(params.logdir, params.adv_ckp)
            if params.fine_tune_bert:
                epoch, batches, pos_in_batch, generator, discriminator, \
                    generator.trainer, dis_optimizer, \
                    generator.bert_trainer, clamp, length = \
                    load_adv_ckp(
                        adv_ckp_path, generator, discriminator,
                        generator.trainer, dis_optimizer,
                        generator.bert_trainer, mle=True)
            else:
                epoch, batches, pos_in_batch, generator, discriminator, \
                    generator.trainer, dis_optimizer, _, clamp, length = \
                    load_adv_ckp(
                        adv_ckp_path, generator, discriminator,
                        generator.trainer, dis_optimizer, mle=True)
            mixed_mle(generator, discriminator, dis_criterion, dis_optimizer,
                      data, params, start_epoch=epoch, start_batches=batches,
                      start_pos_in_batch=pos_in_batch, start_clamp=clamp,
                      start_len=length)
        else:
            # Fresh adversarial run (no checkpoint).
            if params.mle == 'mixed_mle':
                mixed_mle(generator, discriminator, dis_criterion,
                          dis_optimizer, data, params)
            else:
                adv_train(generator, discriminator, dis_criterion,
                          dis_optimizer, data, params)

    if params.evaluate and 'valid' in params.evaluate_split:
        print("================== Evaluating! ===================")
        evaluate(generator, data, params, split='valid')
        print("============= Evaluation finished! ===============")
class Trainer:
    """DCGAN trainer: builds G/D, runs the standard real/fake discriminator
    update followed by a generator update, and plots/saves progress images.
    """

    def __init__(self, nc=1, nz=100, ngf=64, ndf=64, lr=0.0002, beta1=0.5,
                 ngpu=1, autosave=None):
        """
        :param nc: number of image channels
        :param nz: latent vector size
        :param ngf: generator feature-map base width
        :param ndf: discriminator feature-map base width
        :param lr: Adam learning rate (both nets)
        :param beta1: Adam beta1 (DCGAN paper value 0.5)
        :param ngpu: if >0 and CUDA is available, run on cuda:0
        :param autosave: directory to save figures into, or None to disable
        """
        self.nz = nz
        self.dataloader = None
        self.img_list = None
        self.G_losses = None
        self.D_losses = None
        self.device = torch.device("cuda:0" if (
            torch.cuda.is_available() and ngpu > 0) else "cpu")
        self.netG = Generator(nz, ngf, nc).to(self.device)
        self.netG.apply(utils.weights_init)
        self.netD = Discriminator(nc, ndf).to(self.device)
        self.netD.apply(utils.weights_init)
        self.criterion = nn.BCELoss()
        # Fixed noise: re-used every epoch so sample grids show the same
        # latent points evolving.
        self.fixed_noise = torch.randn(64, nz, 1, 1, device=self.device)
        self.real_label = 1.
        self.fake_label = 0.
        self.optimizerD = optim.Adam(self.netD.parameters(), lr=lr,
                                     betas=(beta1, 0.999))
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=lr,
                                     betas=(beta1, 0.999))
        self.result_path = autosave

    # Set the dataloader
    def load_data(self, dataloader):
        self.dataloader = dataloader
        print("Dataloader is prepared!")

    # Train the networks
    def train(self, num_epochs, render=False):
        # Check whether the dataloader is None, if so quit the training
        if self.dataloader is None:
            print("Data has not been loaded yet!")
            return
        # Else start the training
        print("Start training loop...")
        # Re-initialize weights so each train() call starts from scratch.
        self.netG.apply(utils.weights_init)
        self.netD.apply(utils.weights_init)
        self.img_list = []
        self.G_losses = []
        self.D_losses = []
        for epoch in range(num_epochs):
            # For each batch in the dataloader
            for i, data in enumerate(self.dataloader, 0):
                # Train the discriminator with real images
                self.netD.zero_grad()
                real_cpu = data[0].to(self.device)
                b_size = real_cpu.size(0)
                label = torch.full((b_size,), self.real_label,
                                   dtype=torch.float, device=self.device)
                output = self.netD(real_cpu).view(-1)
                errD_real = self.criterion(output, label)
                errD_real.backward()
                D_x = output.mean().item()

                # Train the discriminator with fake images
                noise = torch.randn(b_size, self.nz, 1, 1, device=self.device)
                fake = self.netG(noise)
                label.fill_(self.fake_label)
                # detach(): no generator gradients during the D update.
                output = self.netD(fake.detach()).view(-1)
                errD_fake = self.criterion(output, label)
                errD_fake.backward()
                D_G_z1 = output.mean().item()
                # Update the parameters
                errD = errD_real + errD_fake
                self.optimizerD.step()

                # Train the generator: fake images labelled as real.
                self.netG.zero_grad()
                label.fill_(self.real_label)
                output = self.netD(fake).view(-1)
                errG = self.criterion(output, label)
                errG.backward()
                D_G_z2 = output.mean().item()
                self.optimizerG.step()

                if i % 100 == 0:
                    print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                          % (epoch, num_epochs, i, len(self.dataloader),
                             errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
                self.G_losses.append(errG.item())
                self.D_losses.append(errD.item())

            with torch.no_grad():
                # Fixed noise is employed here, because I want to generate images of same digits for more conspicuous comparisons
                fake = self.netG(self.fixed_noise).detach().cpu()
                self.img_list.append(
                    vutils.make_grid(fake, padding=2, normalize=True))
            # Plot the training result of this epoch
            self.draw_current_image(epoch, show=render)

    # Draw the original image
    def draw_original_image(self):
        real_batch = next(iter(self.dataloader))
        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title("Training Images")
        plt.imshow(np.transpose(
            vutils.make_grid(real_batch[0].to(self.device)[:64], padding=2,
                             normalize=True).cpu(), (1, 2, 0)))
        if self.result_path is not None:
            # FIX: forward slash instead of a hard-coded "\\" so the path
            # works on POSIX systems as well as Windows.
            plt.savefig(self.result_path + "/origin.png")
        plt.show()

    # Plot the loss curves of both the generator and the discriminator
    def plot_loss(self):
        plt.figure(figsize=(10, 5))
        plt.title("Generator and Discriminator Loss During Training")
        plt.plot(self.G_losses, label="G")
        plt.plot(self.D_losses, label="D")
        plt.xlabel("iterations")
        plt.ylabel("Loss")
        plt.legend()
        if self.result_path is not None:
            # FIX: portable separator (see draw_original_image).
            plt.savefig(self.result_path + "/loss.png")
        plt.show()

    # Draw the last 64 figures
    def draw_current_image(self, current_epoch, show=False):
        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title("Fake Images_" + str(current_epoch))
        plt.imshow(np.transpose(self.img_list[-1], (1, 2, 0)))
        if self.result_path is not None:
            # FIX: portable separator (see draw_original_image).
            plt.savefig(self.result_path + "/fake_" + str(current_epoch) +
                        ".png")
        if show:
            plt.show()
class RefSRSolver(BaseSolver):
    """Reference-based super-resolution solver (SRNTT): L1 init phase, then
    full training with perceptual/texture/adversarial losses."""

    def __init__(self, cfg):
        super(RefSRSolver, self).__init__(cfg)
        self.srntt = SRNTT(cfg['model']['n_resblocks'],
                           cfg['schedule']['use_weights'],
                           cfg['schedule']['concat']).cuda()
        # self.discriminator = None
        self.discriminator = Discriminator(cfg['data']['input_size']).cuda()
        # self.vgg = None
        self.vgg = VGG19(cfg['model']['final_layer'],
                         cfg['model']['prev_layer'], True).cuda()
        # Only the texture-transfer/fusion/output sub-modules are optimized.
        params = list(self.srntt.texture_transfer.parameters()) + \
            list(self.srntt.texture_fusion_medium.parameters()) + \
            list(self.srntt.texture_fusion_large.parameters()) + \
            list(self.srntt.srntt_out.parameters())
        self.init_epoch = self.cfg['schedule']['init_epoch']
        self.num_epochs = self.cfg['schedule']['num_epochs']
        self.optimizer_init = torch.optim.Adam(params,
                                               lr=cfg['schedule']['lr'])
        # NOTE: self.optimizer / self.optimizer_d are MultiStepLR *schedulers*
        # wrapping their Adam optimizers (reachable via `.optimizer`).
        self.optimizer = torch.optim.lr_scheduler.MultiStepLR(
            torch.optim.Adam(params, lr=cfg['schedule']['lr']),
            [self.num_epochs // 2], 0.1)
        self.optimizer_d = torch.optim.lr_scheduler.MultiStepLR(
            torch.optim.Adam(self.discriminator.parameters(),
                             lr=cfg['schedule']['lr']),
            [self.num_epochs // 2], 0.1)
        self.reconst_loss = nn.L1Loss()
        self.bp_loss = BackProjectionLoss()
        self.texture_loss = TextureLoss(self.cfg['schedule']['use_weights'],
                                        80)
        self.adv_loss = AdvLoss(self.cfg['schedule']['is_WGAN_GP'])
        # Order: [perceptual, texture, adversarial, back-projection, reconst].
        self.loss_weights = self.cfg['schedule']['loss_weights']

    def train(self):
        if self.epoch <= self.init_epoch:
            # Phase 1: reconstruction-only warm-up.
            with tqdm(total=len(self.train_loader),
                      miniters=1,
                      desc='Initial Training Epoch: [{}/{}]'.format(
                          self.epoch, self.max_epochs)) as t:
                for data in self.train_loader:
                    lr, hr = data['lr'].cuda(), data['hr'].cuda()
                    maps, weight = data['map'].cuda(), data['weight'].cuda()
                    self.srntt.train()
                    self.optimizer_init.zero_grad()
                    sr, srntt_out = self.srntt(lr, weight, maps)
                    loss_reconst = self.reconst_loss(sr, hr)
                    loss_bp = self.bp_loss(lr, srntt_out)
                    loss = self.loss_weights[
                        4] * loss_reconst + self.loss_weights[3] * loss_bp
                    t.set_postfix_str("Batch loss {:.4f}".format(loss.item()))
                    t.update()
                    loss.backward()
                    self.optimizer_init.step()
        elif self.epoch <= self.num_epochs:
            # Phase 2: full training with GAN + perceptual + texture losses.
            with tqdm(total=len(self.train_loader),
                      miniters=1,
                      desc='Complete Training Epoch: [{}/{}]'.format(
                          self.epoch, self.max_epochs)) as t:
                for data in self.train_loader:
                    lr, hr = data['lr'].cuda(), data['hr'].cuda()
                    maps, weight = data['map'].cuda(), data['weight'].cuda()
                    self.srntt.train()
                    self.optimizer_init.zero_grad()
                    self.optimizer.optimizer.zero_grad()
                    self.optimizer_d.optimizer.zero_grad()
                    sr, srntt_out = self.srntt(lr, weight, maps)
                    sr_prevlayer, sr_lastlayer = self.vgg(srntt_out)
                    hr_prevlayer, hr_lastlayer = self.vgg(hr)
                    _, d_real_logits = self.discriminator(hr)
                    _, d_fake_logits = self.discriminator(srntt_out)
                    loss_reconst = self.reconst_loss(sr, hr)
                    loss_bp = self.bp_loss(lr, srntt_out)
                    loss_texture = self.texture_loss(sr_prevlayer, maps,
                                                     weight)
                    loss_d, loss_g = self.adv_loss(srntt_out, hr,
                                                   d_fake_logits,
                                                   d_real_logits,
                                                   self.discriminator)
                    loss_percep = torch.pow(sr_lastlayer - hr_lastlayer,
                                            2).mean()
                    if self.cfg['schedule']['use_lower_layers_in_per_loss']:
                        for l_sr, l_hr in zip(sr_prevlayer, hr_prevlayer):
                            loss_percep += torch.pow(l_sr - l_hr, 2).mean()
                        loss_percep = loss_percep / (len(sr_prevlayer) + 1)
                    # FIX: the original built torch.Tensor([...]) from the
                    # loss tensors, which detaches them from autograd (and
                    # errors on grad-requiring tensors) — backward() could
                    # not reach the network. Use a plain weighted sum so the
                    # graph stays intact.
                    total_loss = sum(
                        w * l for w, l in zip(self.loss_weights, [
                            loss_percep, loss_texture, loss_g, loss_bp,
                            loss_reconst
                        ]))
                    t.set_postfix_str("Batch loss {:.4f}".format(
                        total_loss.item()))
                    t.update()
                    # FIX: retain_graph — loss_d and total_loss share the
                    # discriminator/generator forward graph.
                    loss_d.backward(retain_graph=True)
                    total_loss.backward()
                    # FIX: actually update the parameters; the original only
                    # stepped the LR *schedulers*, so no weights ever moved.
                    # NOTE(review): D's gradients also include the generator
                    # loss term here, as in the original structure — confirm
                    # whether separate D/G updates are intended.
                    self.optimizer.optimizer.step()
                    self.optimizer_d.optimizer.step()
                    self.optimizer.step(self.epoch)
                    self.optimizer_d.step(self.epoch)
        else:
            pass

    def eval(self):
        with tqdm(total=len(self.val_loader),
                  miniters=1,
                  desc='Val Epoch: [{}/{}]'.format(self.epoch,
                                                   self.max_epochs)) as t:
            psnr_list, ssim_list, loss_list = [], [], []
            for lr, hr in self.val_loader:
                lr, hr = lr.cuda(), hr.cuda()
                self.srntt.eval()
                with torch.no_grad():
                    sr, _ = self.srntt(lr, None, None)
                    loss = self.reconst_loss(sr, hr)
                batch_psnr, batch_ssim = [], []
                for c in range(sr.shape[0]):
                    # Map [-1,1] tensors back to HWC images in [0,255].
                    predict_sr = (sr[c, ...].cpu().numpy().transpose(
                        (1, 2, 0)) + 1) * 127.5
                    ground_truth = (hr[c, ...].cpu().numpy().transpose(
                        (1, 2, 0)) + 1) * 127.5
                    psnr = utils.calculate_psnr(predict_sr, ground_truth, 255)
                    ssim = utils.calculate_ssim(predict_sr, ground_truth, 255)
                    batch_psnr.append(psnr)
                    batch_ssim.append(ssim)
                avg_psnr = np.array(batch_psnr).mean()
                avg_ssim = np.array(batch_ssim).mean()
                psnr_list.extend(batch_psnr)
                ssim_list.extend(batch_ssim)
                t.set_postfix_str(
                    'Batch loss: {:.4f}, PSNR: {:.4f}, SSIM: {:.4f}'.format(
                        loss.item(), avg_psnr, avg_ssim))
                t.update()
            self.records['Epoch'].append(self.epoch)
            self.records['PSNR'].append(np.array(psnr_list).mean())
            self.records['SSIM'].append(np.array(ssim_list).mean())
            self.logger.log('Val Epoch {}: PSNR={}, SSIM={}'.format(
                self.epoch, self.records['PSNR'][-1],
                self.records['SSIM'][-1]))

    def save_checkpoint(self):
        super(RefSRSolver, self).save_checkpoint()
        self.ckp['srntt'] = self.srntt.state_dict()
        self.ckp['optimizer'] = self.optimizer.state_dict()
        self.ckp['optimizer_d'] = self.optimizer_d.state_dict()
        self.ckp['optimizer_init'] = self.optimizer_init.state_dict()
        if self.discriminator is not None:
            self.ckp['discriminator'] = self.discriminator.state_dict()
        if self.vgg is not None:
            self.ckp['vgg'] = self.vgg.state_dict()
        torch.save(self.ckp, os.path.join(self.checkpoint_dir, 'latest.pth'))
        # Keep a copy of the best-PSNR checkpoint.
        if self.records['PSNR'][-1] == np.array(self.records['PSNR']).max():
            shutil.copy(os.path.join(self.checkpoint_dir, 'latest.pth'),
                        os.path.join(self.checkpoint_dir, 'best.pth'))

    def load_checkpoint(self, model_path):
        super(RefSRSolver, self).load_checkpoint(model_path)
        ckpt = torch.load(model_path)
        self.srntt.load_state_dict(ckpt['srntt'])
        self.optimizer.load_state_dict(ckpt['optimizer'])
        self.optimizer_d.load_state_dict(ckpt['optimizer_d'])
        self.optimizer_init.load_state_dict(ckpt['optimizer_init'])
        if 'vgg' in ckpt.keys() and self.vgg is not None:
            # FIX: was `load_stat_dict(ckpt['srntt'])` — misspelled method
            # (AttributeError) and the SRNTT weights instead of the VGG ones.
            self.vgg.load_state_dict(ckpt['vgg'])
        if 'discriminator' in ckpt.keys() and self.discriminator is not None:
            self.discriminator.load_state_dict(ckpt['discriminator'])
class Trial:
    """One AnimeGAN-style training run: a Generator/Discriminator pair, their
    optimizers, the dataset/dataloader, and TensorBoard logging.

    Offers several training schedules: ``init_train`` (content-only
    pre-training), ``train_1``/``train_2`` (classic adversarial training), and
    the NOGAN variants (``Generator_NOGAN``/``Discriminator_NOGAN``/
    ``GAN_NOGAN``) that train each network separately before a short joint
    GAN phase.
    """

    def __init__(self,
                 data_dir: str = './dataset',
                 log_dir: str = './logs',
                 device: str = "cuda:0",
                 batch_size: int = 2,
                 init_lr: float = 0.5,
                 G_lr: float = 0.0004,
                 D_lr: float = 0.0008,
                 level: str = "O1",
                 patch: bool = False,
                 init_training_epoch: int = 10,
                 train_epoch: int = 10,
                 optim_type: str = "ADAM",
                 pin_memory: bool = True,
                 grad_set_to_none: bool = True):
        # self.config = config
        self.data_dir = data_dir
        self.dataset = Dataset(root=data_dir + "/Shinkai",
                               style_transform=tr.transform,
                               smooth_transform=tr.transform)
        self.pin_memory = pin_memory
        self.batch_size = batch_size
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=4,
                                     pin_memory=pin_memory)
        # Fall back to CPU when CUDA is unavailable, regardless of `device`.
        self.device = torch.device(
            device) if torch.cuda.is_available() else torch.device('cpu')
        self.G = Generator().to(self.device)
        self.patch = patch
        if self.patch:
            self.D = PatchDiscriminator().to(self.device)
        else:
            self.D = Discriminator().to(self.device)
        self.init_model_weights()
        self.optimizer_G = GANOptimizer(optim_type,
                                        self.G.parameters(),
                                        lr=G_lr,
                                        betas=(0.5, 0.999),
                                        amsgrad=False)
        self.optimizer_D = GANOptimizer(optim_type,
                                        self.D.parameters(),
                                        lr=D_lr,
                                        betas=(0.5, 0.999),
                                        amsgrad=True)
        self.loss = Loss(device=self.device).to(self.device)
        self.init_lr = init_lr
        self.G_lr = G_lr
        self.D_lr = D_lr
        self.grad_set_to_none = grad_set_to_none
        self.writer = tensorboard.SummaryWriter(log_dir=log_dir)
        self.init_train_epoch = init_training_epoch
        self.train_epoch = train_epoch
        # Timestamp tag used in all TensorBoard captions; set lazily by
        # get_test_image() / the training entry points.
        self.init_time = None
        self.level = level
        # Apex AMP mixed precision unless opt level is "O0" or running on CPU.
        if self.level != "O0" and device != "cpu":
            self.fp16 = True
            [self.G, self.D], [
                self.optimizer_G, self.optimizer_D
            ] = amp.initialize([self.G, self.D],
                               [self.optimizer_G, self.optimizer_D],
                               opt_level=self.level)
        else:
            self.fp16 = False

    def init_model_weights(self):
        """Apply the module-level `weights_init` initializer to G and D."""
        self.G.apply(weights_init)
        self.D.apply(weights_init)

    @classmethod
    def from_config(cls):
        # Placeholder alternate constructor; not implemented.
        pass

    def init_train(self, con_weight: float = 1.0):
        """Content-only pre-training of the generator (reconstruction phase).

        Args:
            con_weight: multiplier applied to the content loss.
        """
        test_img = self.get_test_image()
        meter = AverageMeter("Loss")
        self.writer.flush()
        # NOTE(review): the scheduler is built with max_lr=0.9999 and the
        # group lr is then overwritten with self.init_lr below — confirm this
        # interaction with OneCycleLR is intended.
        lr_scheduler = OneCycleLR(self.optimizer_G,
                                  max_lr=0.9999,
                                  steps_per_epoch=len(self.dataloader),
                                  epochs=self.init_train_epoch)
        for g in self.optimizer_G.param_groups:
            g['lr'] = self.init_lr
        for epoch in tqdm(range(self.init_train_epoch)):
            meter.reset()
            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                # train = transform(test_img).unsqueeze(0)
                self.G.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)
                generator_output = self.G(train)
                # content_loss = loss.reconstruction_loss(generator_output, train) * con_weight
                content_loss = self.loss.content_loss(generator_output,
                                                      train) * con_weight
                # content_loss = F.mse_loss(train, generator_output) * con_weight
                content_loss.backward()
                self.optimizer_G.step()
                lr_scheduler.step()
                meter.update(content_loss.detach())
            self.writer.add_scalar(f"Loss : {self.init_time}",
                                   meter.sum.item(), epoch)
            self.write_weights(epoch + 1, write_D=False)
            self.eval_image(epoch, f'{self.init_time} reconstructed img',
                            test_img)
        # Restore the regular generator learning rate for later phases.
        for g in self.optimizer_G.param_groups:
            g['lr'] = self.G_lr
        # self.save_trial(self.init_train_epoch, "init")

    def eval_image(self, epoch: int, caption, img):
        """Feeds in one single image to process and save."""
        self.G.eval()
        styled_test_img = tr.transform(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            styled_test_img = self.G(styled_test_img)
        styled_test_img = styled_test_img.to('cpu').squeeze()
        self.write_image(styled_test_img, caption, epoch + 1)
        self.writer.flush()
        self.G.train()

    def write_image(self,
                    image: torch.Tensor,
                    img_caption: str = "sample_image",
                    step: int = 0):
        """De-normalize a CHW tensor and log it to TensorBoard as HWC uint8."""
        image = torch.clip(tr.inv_norm(image).to(torch.float), 0, 1)
        # [-1, 1] -> [0, 1]
        image *= 255.
        # [0, 1] -> [0, 255]
        image = image.permute(1, 2, 0).to(dtype=torch.uint8)
        self.writer.add_image(img_caption, image, step, dataformats='HWC')
        self.writer.flush()

    def write_weights(self, epoch: int, write_D=True, write_G=True):
        """Log weight and gradient histograms for D and/or G."""
        if write_D:
            for name, weight in self.D.named_parameters():
                # Only the separable-conv layers of D are logged.
                if 'depthwise' in name or 'pointwise' in name:
                    self.writer.add_histogram(
                        f"Discriminator {name} {self.init_time}", weight,
                        epoch)
                    self.writer.add_histogram(
                        f"Discriminator {name}.grad {self.init_time}",
                        weight.grad, epoch)
            self.writer.flush()
        if write_G:
            for name, weight in self.G.named_parameters():
                self.writer.add_histogram(
                    f"Generator {name} {self.init_time}", weight, epoch)
                self.writer.add_histogram(
                    f"Generator {name}.grad {self.init_time}", weight.grad,
                    epoch)
            self.writer.flush()

    def train_1(
        self,
        adv_weight: float = 300.,
        con_weight: float = 1.5,
        gra_weight: float = 3.,
        col_weight: float = 10.,
    ):
        """Full adversarial training with the complete AnimeGAN loss mix."""
        test_img_dir = Path(
            self.data_dir).joinpath('test/test_photo256').resolve()
        test_img_dir = random.choice(list(test_img_dir.glob('**/*')))
        test_img = Image.open(test_img_dir)
        self.writer.add_image(f'test image {self.init_time}',
                              np.asarray(test_img),
                              dataformats='HWC')
        self.writer.flush()
        for epoch in tqdm(range(self.train_epoch)):
            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                self.D.zero_grad()
                style = style.to(self.device)
                smooth = smooth.to(self.device)
                train = train.to(self.device)
                # style image to discriminator(Not Gram Matrix Loss)
                style_loss_value = self.D(style).view(-1)
                generator_output = self.G(train)
                # generated image to discriminator
                real_output = self.D(generator_output.detach()).view(-1)
                # greyscale_output = D(transforms.functional.rgb_to_grayscale(train, num_output_channels=3)).view(-1) #greyscale adversarial loss
                gray_train = tr.inv_gray_transform(train)
                greyscale_output = self.D(gray_train).view(-1)
                smoothed_loss = self.D(smooth).view(-1)
                # smoothed image loss
                # loss_D_real = adversarial_loss(output, label)
                dis_adv_loss = adv_weight * (
                    torch.pow(style_loss_value - 1, 2).mean() +
                    torch.pow(real_output, 2).mean())
                dis_gray_loss = torch.pow(greyscale_output, 2).mean()
                dis_edge_loss = torch.pow(smoothed_loss, 2).mean()
                discriminator_loss = dis_adv_loss + dis_gray_loss + dis_edge_loss
                discriminator_loss.backward()
                self.optimizer_D.step()
                if i % 200 == 0 and i != 0:
                    self.writer.add_scalars(
                        f'{self.init_time} Discriminator losses', {
                            'adversarial loss': dis_adv_loss.item(),
                            'grayscale loss': dis_gray_loss.item(),
                            'edge loss': dis_edge_loss.item()
                        }, i + epoch * len(self.dataloader))
                    self.writer.flush()
                # Second D pass (without detach) so gradients reach G.
                real_output = self.D(generator_output).view(-1)
                per_loss = self.loss.perceptual_loss(
                    train, generator_output)
                # loss for G
                style_loss = self.loss.style_loss(generator_output, style)
                content_loss = self.loss.content_loss(generator_output, train)
                recon_loss = self.loss.reconstruction_loss(
                    generator_output, train)
                tv_loss = self.loss.tv_loss(generator_output)
                '''
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, num_epoch, i, len(data_loader),
                         loss_D.item(), loss_G.item(), D_x, D_G_z1, D_G_z2))'''
                self.G.zero_grad()
                gen_adv_loss = adv_weight * torch.pow(real_output - 1,
                                                      2).mean()
                gen_con_loss = con_weight * content_loss
                gen_sty_loss = gra_weight * style_loss
                gen_rec_loss = col_weight * recon_loss
                gen_per_loss = per_loss
                gen_tv_loss = tv_loss  # NOTE(review): computed but not summed below — confirm intended
                generator_loss = gen_adv_loss + gen_con_loss + gen_sty_loss + gen_rec_loss + gen_per_loss
                generator_loss.backward()
                self.optimizer_G.step()
                if i % 200 == 0 and i != 0:
                    self.writer.add_scalars(
                        f'generator losses {self.init_time}', {
                            'adversarial loss': gen_adv_loss.item(),
                            'content loss': gen_con_loss.item(),
                            'style loss': gen_sty_loss.item(),
                            'reconstruction loss': gen_rec_loss.item(),
                            'perceptual loss': gen_per_loss.item()
                        }, i + epoch * len(self.dataloader))
                    self.writer.flush()
            self.write_weights(epoch + 1)
            self.eval_image(epoch, f'{self.init_time} style img', test_img)

    def train_2(self,
                adv_weight: float = 1.0,
                threshold: float = 3.,
                G_train_iter: int = 1,
                D_train_iter: int = 1):
        """Adversarial-only training with a slowly ramped perceptual weight."""
        # if threshold is 0., set to half of adversarial loss
        test_img_dir = Path(self.data_dir).joinpath('test', 'test_photo256')
        test_img_dir = random.choice(list(test_img_dir.glob('**/*')))
        test_img = Image.open(test_img_dir)
        if self.init_time is None:
            self.init_time = datetime.datetime.now().strftime("%H:%M")
        self.writer.add_image(f'sample_image {self.init_time}',
                              np.asarray(test_img),
                              dataformats='HWC')
        self.writer.flush()
        perception_weight = 0.
        keep_constant = False
        for epoch in tqdm(range(self.train_epoch)):
            total_dis_loss = 0.
            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                self.D.zero_grad()
                train = train.to(self.device)
                style = style.to(self.device)
                # smooth = smooth.to(device)
                for _ in range(D_train_iter):
                    style_loss_value = self.D(style).view(-1)
                    generator_output = self.G(train)
                    real_output = self.D(generator_output.detach()).view(-1)
                    dis_adv_loss = adv_weight * \
                        (torch.pow(style_loss_value - 1, 2).mean() +
                         torch.pow(real_output, 2).mean())
                    total_dis_loss += dis_adv_loss.item()
                    dis_adv_loss.backward()
                    self.optimizer_D.step()
                self.G.zero_grad()
                for _ in range(G_train_iter):
                    generator_output = self.G(train)
                    real_output = self.D(generator_output).view(-1)
                    per_loss = perception_weight * \
                        self.loss.perceptual_loss(train, generator_output)
                    gen_adv_loss = adv_weight * torch.pow(real_output - 1,
                                                          2).mean()
                    gen_loss = gen_adv_loss + per_loss
                    gen_loss.backward()
                    self.optimizer_G.step()
                if i % 200 == 0 and i != 0:
                    self.writer.add_scalars(
                        f'generator losses {self.init_time}', {
                            'adversarial loss': dis_adv_loss.item(),
                            'Generator adversarial loss':
                            gen_adv_loss.item(),
                            'perceptual loss': per_loss.item()
                        }, i + epoch * len(self.dataloader))
                    self.writer.flush()
            # Ramp the perceptual weight while D is still losing; freeze it
            # permanently once the epoch's total D loss drops below threshold.
            if total_dis_loss > threshold and not keep_constant:
                perception_weight += 0.05
            else:
                keep_constant = True
            self.writer.add_scalar(
                f'total discriminator loss {self.init_time}', total_dis_loss,
                i + epoch * len(self.dataloader))
            # FIX: was `self.write_weights()` — `epoch` is a required
            # positional argument, so every epoch raised TypeError here.
            self.write_weights(epoch + 1)
            self.G.eval()
            styled_test_img = tr.transform(test_img).unsqueeze(0).to(
                self.device)
            with torch.no_grad():
                styled_test_img = self.G(styled_test_img)
            styled_test_img = styled_test_img.to('cpu').squeeze()
            self.write_image(styled_test_img,
                             f'styled image {self.init_time}', epoch + 1)
            self.G.train()

    def __call__(self):
        """Default schedule: content pre-training followed by train_1."""
        self.init_train()
        self.train_1()

    def save_trial(self, epoch: int, train_type: str):
        """Save G/D weights, optimizer states and (if enabled) AMP state."""
        save_dir = Path(f"{train_type}_{self.level}.pt")
        training_details = {
            "epoch": epoch,
            "gen": {
                "gen_state_dict": self.G.state_dict(),
                "optim_G_state_dict": self.optimizer_G.state_dict()
            },
            "dis": {
                "dis_state_dict": self.D.state_dict(),
                "optim_D_state_dict": self.optimizer_D.state_dict()
            }
        }
        if self.fp16:
            training_details["amp"] = amp.state_dict()
        torch.save(training_details, save_dir.as_posix())

    def load_trial(self, dir: Path):
        """Restore a checkpoint written by ``save_trial``."""
        assert dir.is_file(), "No such directory"
        assert dir.suffix == ".pt", "Filetype not compatible"
        state_dicts = torch.load(dir.as_posix())
        self.G.load_state_dict(state_dicts["gen"]["gen_state_dict"])
        self.optimizer_G.load_state_dict(
            state_dicts["gen"]["optim_G_state_dict"])
        self.D.load_state_dict(state_dicts["dis"]["dis_state_dict"])
        self.optimizer_D.load_state_dict(
            state_dicts["dis"]["optim_D_state_dict"])
        if self.fp16:
            amp.load_state_dict(state_dicts["amp"])
        typer.echo("Loaded Weights")

    def Generator_NOGAN(self,
                        epochs: int = 1,
                        style_weight: float = 20.,
                        content_weight: float = 1.2,
                        recon_weight: float = 10.,
                        tv_weight: float = 1e-6,
                        loss: List[str] = ['content_loss']):
        """Training Generator in NOGAN manner (Feature Loss only)."""
        for g in self.optimizer_G.param_groups:
            g['lr'] = self.G_lr
        test_img = self.get_test_image()
        max_lr = self.G_lr * 10.
        lr_scheduler = OneCycleLR(self.optimizer_G,
                                  max_lr=max_lr,
                                  steps_per_epoch=len(self.dataloader),
                                  epochs=epochs)
        meter = LossMeters(*loss)
        total_loss_arr = np.array([])
        for epoch in tqdm(range(epochs)):
            total_losses = 0
            meter.reset()
            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                # train = transform(test_img).unsqueeze(0)
                self.G.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)
                generator_output = self.G(train)
                # Each loss term is only computed when requested in `loss`.
                if 'style_loss' in loss:
                    style = style.to(self.device)
                    style_loss = self.loss.style_loss(generator_output,
                                                      style) * style_weight
                else:
                    style_loss = 0.
                if 'content_loss' in loss:
                    content_loss = self.loss.content_loss(
                        generator_output, train) * content_weight
                else:
                    content_loss = 0.
                if 'recon_loss' in loss:
                    recon_loss = self.loss.reconstruction_loss(
                        generator_output, train) * recon_weight
                else:
                    recon_loss = 0.
                if 'tv_loss' in loss:
                    tv_loss = self.loss.tv_loss(generator_output) * tv_weight
                else:
                    tv_loss = 0.
                total_loss = content_loss + tv_loss + recon_loss + style_loss
                if self.fp16:
                    with amp.scale_loss(total_loss,
                                        self.optimizer_G) as scaled_loss:
                        scaled_loss.backward()
                else:
                    total_loss.backward()
                self.optimizer_G.step()
                lr_scheduler.step()
                total_losses += total_loss.detach()
                loss_dict = {
                    'content_loss': content_loss,
                    'style_loss': style_loss,
                    'recon_loss': recon_loss,
                    'tv_loss': tv_loss
                }
                losses = [loss_dict[loss_type].detach() for loss_type in loss]
                meter.update(*losses)
            total_loss_arr = np.append(total_loss_arr, total_losses.item())
            self.writer.add_scalars(
                f'{self.init_time} NOGAN generator losses',
                meter.as_dict('sum'), epoch)
            self.write_weights(epoch + 1, write_D=False)
            self.eval_image(epoch, f'{self.init_time} reconstructed img',
                            test_img)
            # Early stop once the loss curve's slope flattens above `thresh`.
            if epoch > 2:
                fig = plt.figure(figsize=(8, 8))
                X = np.arange(len(total_loss_arr))
                Y = np.gradient(total_loss_arr)
                plt.plot(X, Y)
                thresh = -1.0
                plt.axhline(thresh, c='r')
                plt.title(f"{self.init_time}")
                self.writer.add_figure(f"{self.init_time}", fig, epoch)
                if Y[-1] > thresh:
                    break
        self.save_trial(epoch, f'G_NG_{self.init_time}')

    def Discriminator_NOGAN(
            self,
            epochs: int = 3,
            adv_weight: float = 1.0,
            edge_weight: float = 1.0,
            loss: List[str] = ['real_adv_loss', 'fake_adv_loss',
                               'gray_loss']):
        """https://discuss.pytorch.org/t/scheduling-batch-size-in-dataloader/46443/2"""
        for g in self.optimizer_D.param_groups:
            g['lr'] = self.D_lr
        max_lr = self.D_lr * 10.
        lr_scheduler = OneCycleLR(self.optimizer_D,
                                  max_lr=max_lr,
                                  steps_per_epoch=len(self.dataloader),
                                  epochs=epochs)
        meter = LossMeters(*loss)
        total_loss_arr = np.array([])
        if self.init_time is None:
            self.init_time = datetime.datetime.now().strftime("%H:%M")
        for epoch in tqdm(range(epochs)):
            meter.reset()
            total_losses = 0.
            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                # train = transform(test_img).unsqueeze(0)
                self.D.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)
                style = style.to(self.device)
                generator_output = self.G(train)
                real_adv_loss = self.D(style).view(-1)
                fake_adv_loss = self.D(generator_output.detach()).view(-1)
                real_adv_loss = torch.pow(real_adv_loss - 1,
                                          2).mean() * 1.7 * adv_weight
                fake_adv_loss = torch.pow(fake_adv_loss,
                                          2).mean() * 1.7 * adv_weight
                gray_train = tr.inv_gray_transform(style)
                greyscale_output = self.D(gray_train).view(-1)
                gray_loss = torch.pow(greyscale_output,
                                      2).mean() * 1.7 * adv_weight
                "According to AnimeGANv2 implementation, every loss is scaled by individual weights and then scaled with adv_weight"
                "https://github.com/TachibanaYoshino/AnimeGANv2/blob/5946b6afcca5fc28518b75a763c0f561ff5ce3d6/tools/ops.py#L217"
                total_loss = real_adv_loss + fake_adv_loss + gray_loss
                if self.fp16:
                    with amp.scale_loss(total_loss,
                                        self.optimizer_D) as scaled_loss:
                        scaled_loss.backward()
                else:
                    total_loss.backward()
                self.optimizer_D.step()
                lr_scheduler.step()
                loss_dict = {
                    'real_adv_loss': real_adv_loss,
                    'fake_adv_loss': fake_adv_loss,
                    'gray_loss': gray_loss
                }
                losses = [loss_dict[loss_type].detach() for loss_type in loss]
                meter.update(*losses)
                total_losses += total_loss.detach()
            # FIX: `total_loss_arr` was never appended to, so the
            # `np.gradient` early-stop check below operated on an empty
            # array and raised once `epoch > 2`. Track the per-epoch total
            # the same way Generator_NOGAN does.
            total_loss_arr = np.append(total_loss_arr, float(total_losses))
            self.writer.add_scalars(
                f'{self.init_time} NOGAN discriminator loss',
                meter.as_dict('sum'), epoch)
            self.writer.flush()
            if epoch > 2:
                fig = plt.figure(figsize=(8, 8))
                X = np.arange(len(total_loss_arr))
                Y = np.gradient(total_loss_arr)
                plt.plot(X, Y)
                thresh = -1.0
                plt.axhline(thresh, c='r')
                plt.title(f"{self.init_time}")
                self.writer.add_figure(f"{self.init_time}", fig, epoch)
                if Y[-1] > thresh:
                    break

    def GAN_NOGAN(
        self,
        epochs: int = 1,
        GAN_G_lr: float = 0.00008,
        GAN_D_lr: float = 0.000016,
        D_loss: List[str] = [
            "real_adv_loss", "fake_adv_loss", "gray_loss", "edge_loss"
        ],
        adv_weight: float = 300.,
        edge_weight: float = 0.1,
        G_loss: List[str] = [
            "adv_loss", "content_loss", "style_loss", "recon_loss"
        ],
        style_weight: float = 20.,
        content_weight: float = 1.2,
        recon_weight: float = 10.,
        tv_weight: float = 1e-6,
    ):
        """Short joint GAN phase after the NOGAN pre-training stages."""
        test_img = self.get_test_image()
        dis_meter = LossMeters(*D_loss)
        gen_meter = LossMeters(*G_loss)
        for g in self.optimizer_G.param_groups:
            g['lr'] = GAN_G_lr
        for g in self.optimizer_D.param_groups:
            g['lr'] = GAN_D_lr
        update_duration = len(self.dataloader) // 20
        for epoch in tqdm(range(epochs)):
            G_loss_arr = np.array([])
            dis_meter.reset()
            count = 0
            for i, (style, smooth, train) in enumerate(self.dataloader, 0):
                self.D.zero_grad(set_to_none=self.grad_set_to_none)
                train = train.to(self.device)
                style = style.to(self.device)
                smooth = smooth.to(self.device)
                generator_output = self.G(train)
                real_adv_loss = self.D(style).view(-1)
                fake_adv_loss = self.D(generator_output.detach()).view(-1)
                # Separate non-detached pass so G gradients flow later.
                G_adv_loss = self.D(generator_output).view(-1)
                gray_train = tr.inv_gray_transform(style)
                grayscale_output = self.D(gray_train).view(-1)
                gray_smooth_data = tr.inv_gray_transform(smooth)
                smoothed_output = self.D(smooth).view(-1)
                real_adv_loss = torch.square(real_adv_loss -
                                             1.).mean() * 1.7 * adv_weight
                fake_adv_loss = torch.square(
                    fake_adv_loss).mean() * 1.7 * adv_weight
                gray_loss = torch.square(
                    grayscale_output).mean() * 1.7 * adv_weight
                edge_loss = torch.square(
                    smoothed_output).mean() * 1.0 * adv_weight
                total_D_loss = real_adv_loss + fake_adv_loss + gray_loss + edge_loss
                total_D_loss.backward()
                self.optimizer_D.step()
                D_loss_dict = {
                    'real_adv_loss': real_adv_loss,
                    'fake_adv_loss': fake_adv_loss,
                    'gray_loss': gray_loss,
                    'edge_loss': edge_loss
                }
                loss = list(D_loss_dict.values())
                dis_meter.update(*loss)
                if i % update_duration == 0 and i != 0:
                    self.writer.add_scalars(
                        f'{self.init_time} NOGAN Dis loss',
                        dis_meter.as_dict('val'),
                        i + epoch * len(self.dataloader))
                    self.writer.flush()
                self.G.zero_grad(set_to_none=self.grad_set_to_none)
                G_adv_loss = torch.square(G_adv_loss - 1.).mean() * adv_weight
                if 'style_loss' in G_loss:
                    style_loss = self.loss.style_loss(generator_output,
                                                      style) * style_weight
                else:
                    style_loss = 0.
                if 'content_loss' in G_loss:
                    content_loss = self.loss.content_loss(
                        generator_output, train) * content_weight
                else:
                    content_loss = 0.
                if 'recon_loss' in G_loss:
                    recon_loss = self.loss.reconstruction_loss(
                        generator_output, train) * recon_weight
                else:
                    recon_loss = 0.
                if 'tv_loss' in G_loss:
                    tv_loss = self.loss.tv_loss(generator_output) * tv_weight
                else:
                    tv_loss = 0.
                total_G_loss = G_adv_loss + content_loss + tv_loss + recon_loss + style_loss
                total_G_loss.backward()
                self.optimizer_G.step()
                G_loss_dict = {
                    'adv_loss': G_adv_loss,
                    'content_loss': content_loss,
                    'style_loss': style_loss,
                    'recon_loss': recon_loss,
                    'tv_loss': tv_loss
                }
                losses = [
                    G_loss_dict[loss_type].detach() for loss_type in G_loss
                ]
                gen_meter.update(*losses)
                if i % update_duration == 0 and i != 0:
                    self.writer.add_scalars(
                        f'{self.init_time} NOGAN Gen loss',
                        gen_meter.as_dict('val'),
                        i + epoch * len(self.dataloader))
                    self.writer.flush()
                G_loss_arr = np.append(G_loss_arr, G_adv_loss.item())
                self.eval_image(i + epoch * len(self.dataloader),
                                f'{self.init_time} reconstructed img',
                                test_img)
        self.save_trial(epoch, f'GAN_NG_{self.init_time}')

    def get_test_image(self):
        """Get random test image."""
        test_img_dir = Path(self.data_dir).joinpath('test/test_photo256')
        test_img_dir = random.choice(list(test_img_dir.glob('**/*')))
        test_img = Image.open(test_img_dir)
        self.init_time = datetime.datetime.now().strftime("%H:%M")
        self.writer.add_image(f'{self.init_time} sample_image',
                              np.asarray(test_img),
                              dataformats='HWC')
        self.writer.flush()
        return test_img
} cvae = CVAE(opts.latent_size, device).to(device) dis = Discriminator().to(device) classifier = Classifier(opts.latent_size).to(device) classer = CLASSIFIERS().to(device) print(cvae) print(dis) print(classifier) optimizer_cvae = torch.optim.Adam(cvae.parameters(), lr=opts.lr, betas=(opts.b1, opts.b2), weight_decay=opts.weight_decay) optimizer_dis = torch.optim.Adam(dis.parameters(), lr=opts.lr, betas=(opts.b1, opts.b2), weight_decay=opts.weight_decay) optimizer_classifier = torch.optim.Adam(classifier.parameters(), lr=opts.lr, betas=(opts.b1, opts.b2), weight_decay=opts.weight_decay) i = 1 while os.path.isdir('./ex/' + str(i)): i += 1 os.mkdir('./ex/' + str(i)) output_path = './ex/' + str(i) losses = {
model_G.cuda() model_D.cuda() print('cuda is available!') else: print('cuda is not available') # パラメータ設定 # params_G = optim.Adam(model_G.parameters(), # lr=0.0002, betas=(0.5, 0.999)) # params_D = optim.Adam(model_D.parameters(), # lr=0.0002, betas=(0.5, 0.999)) params_G = optim.Adam(model_G.parameters(), lr=0.01) params_D = optim.Adam(model_D.parameters(), lr=0.01) # 潜在特徴100次元ベクトルz nz = 100 # ロスを計算するときのラベル変数 if cuda: ones = torch.ones(batch_size).cuda() # 正例 1 zeros = torch.zeros(batch_size).cuda() # 負例 0 loss_f = nn.BCEWithLogitsLoss() # 途中結果の確認用の潜在特徴z check_z = torch.randn(batch_size, nz, 1, 1).cuda()
def train():
    """Train a DCGAN (Generator vs. Discriminator) on an image folder.

    Hyper-parameters are hard-coded below; saves sample grids from a fixed
    noise vector during training and the full models at the end.
    """
    # Random Seed
    manual_seed = random.randint(1, 10000)
    print('Random Seed: ', manual_seed)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)
    # Parameter
    dataroot = 'E:\\datasets\\ukiyoe-1024'
    workers = 2
    batch_size = 128
    image_size = 64
    nz = 100  # latent vector size fed to the generator
    num_epochs = 100
    lr = 0.0002
    beta1 = 0.5
    ngpu = 1
    # create dataset and dataloader
    dataset = dset.ImageFolder(root=dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.CenterCrop(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers)
    # define device
    device = torch.device('cuda:0' if (
        torch.cuda.is_available() and ngpu > 0) else 'cpu')
    print('device: ', device)
    netG = Generator(ngpu).to(device)
    netG.apply(weight_init)
    netD = Discriminator(ngpu).to(device)
    netD.apply(weight_init)
    criterion = nn.BCELoss()
    # Fixed latent batch used to visualize generator progress over time.
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)
    real_label = 1.
    fake_label = 0.
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    # training loop
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0
    print('Starting Training Loop.')
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            # Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            netD.zero_grad()
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size, ),
                               real_label,
                               dtype=torch.float,
                               device=device)
            output = netD(real_cpu).view(-1)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            # detach() keeps the D update from backpropagating into G.
            output = netD(fake.detach()).view(-1)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()
            # Update G network: maximize log(D(G(z))) via the real label.
            netG.zero_grad()
            label.fill_(real_label)
            output = netD(fake).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()
            # Output training stats
            if i % 50 == 0:
                print(
                    '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                    % (epoch + 1, num_epochs, i, len(dataloader), errD.item(),
                       errG.item(), D_x, D_G_z1, D_G_z2))
            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())
            # Check how the generator is doing by saving G's output on fixed_noise
            if (iters % 100 == 0) or ((epoch == num_epochs - 1) and
                                      (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()
                img_list.append(
                    vutils.make_grid(fake, padding=2, normalize=True))
            iters += 1
    save_g_images(epoch, img_list)
    model_path_G = '.\\output\\model\\generator.pth'
    model_path_D = '.\\output\\model\\discriminator.pth'
    # NOTE(review): saves the whole pickled modules rather than state_dicts;
    # loading therefore depends on the class definitions being importable.
    torch.save(netG, model_path_G)
    torch.save(netD, model_path_D)
    print('Finish training.')
disc.apply(init_weights) disc.to(device) #initialize the gan,content,perceptual and brightness loss gan_criterion = nn.BCELoss().to(device) content_criterion = nn.L1Loss().to(device) perceptual_criterion = nn.MSELoss().to(device) brightness_criterion = nn.L1Loss().to(device) #set the feature extractor to evaluation mode as it will be used only to calculate perceptual loss feature_extractor = FeatureExtractor() feature_extractor.eval() feature_extractor.to(device) #initialize the optimizers for Generator and Discriminator optimizerD = optim.Adam(disc.parameters(), lr=0.0003, betas=(0.5, 0.999)) optimizerG = optim.Adam(gen.parameters(), lr=0.0001, betas=(0.5, 0.999)) alpha = 0.5 beta = 1.8 gamma = 1.97 delta = 0.069 resume_epoch = 0 for e in range(resume_epoch, epochs): for i, data in enumerate(train_loader): hazy_images, clear_images = data #to prevent accumulation of gradients
try: if not os.path.exists(directory): os.makedirs(directory) print("Create directory: " + directory) except OSError: print('Error: Creating directory. ' + directory) torch.save(model.state_dict(), "ignore/weights/%s/%s_s.pth" % (dataset, tb, mode)) # init optimizer and scheduler print(" " * 75, "\r", "Loading optimizer...", end="\r") optimizer_G = torch.optim.SGD(itertools.chain(G_model_AtoB.parameters(), G_model_BtoA.parameters()), lr=args.lr) #, betas=(0.5, 0.999)) optimizer_D_A = torch.optim.SGD(D_model_A.parameters(), lr=args.lr) #, betas=(0.5, 0.999)) optimizer_D_B = torch.optim.SGD(D_model_B.parameters(), lr=args.lr) #, betas=(0.5, 0.999)) lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR( args.epochs, 0, 100).step) lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR( args.epochs, 0, 100).step) lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR( args.epochs, 0, 100).step)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") generator = Generator(512, 512).to(device) discriminator = Discriminator().to(device) g_optimizer = torch.optim.Adam( [{ 'params': generator.generator_mapping.parameters(), 'lr': 0.001 * 0.01 }, { 'params': generator.generator_synth.parameters() }], lr=0.001, betas=(0., 0.999)) d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001, betas=(0., 0.99)) ############ summ_counter = 0 mean_losses = np.zeros(5) batch_sizes = [256, 128, 64, 32, 16, 8] epoch_sizes = [2, 4, 4, 8, 8, 16] latent_const = torch.from_numpy(np.load('randn.npy')).float().to(device) transform = transforms.Compose([ transforms.CenterCrop([178, 178]), transforms.Resize([128, 128]), transforms.RandomHorizontalFlip(0.5),
def main():
    """GAIL-style training loop: alternately update a discriminator on
    expert vs. generated trajectories and update the policy/value nets with
    PPO, using the discriminator output as the reward signal.

    Relies on module-level globals (args, device, write_scalar, load_model,
    model_name, onehot_action_sections, onehot_state_sections).
    """
    ## load std models
    # policy_log_std = torch.load('./model_pkl/policy_net_action_std_model_1.pkl')
    # transition_log_std = torch.load('./model_pkl/transition_net_state_std_model_1.pkl')
    # load expert data
    print(args.data_set_path)
    dataset = ExpertDataSet(args.data_set_path)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=args.expert_batch_size,
                                  shuffle=True,
                                  num_workers=0)
    # define actor/critic/discriminator net and optimizer
    policy = Policy(onehot_action_sections,
                    onehot_state_sections,
                    state_0=dataset.state)
    value = Value()
    discriminator = Discriminator()
    optimizer_policy = torch.optim.Adam(policy.parameters(),
                                        lr=args.policy_lr)
    optimizer_value = torch.optim.Adam(value.parameters(), lr=args.value_lr)
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=args.discrim_lr)
    discriminator_criterion = nn.BCELoss()
    if write_scalar:
        writer = SummaryWriter(log_dir='runs/' + model_name)
    # load net models
    if load_model:
        discriminator.load_state_dict(
            torch.load('./model_pkl/Discriminator_model_' + model_name +
                       '.pkl'))
        policy.transition_net.load_state_dict(
            torch.load('./model_pkl/Transition_model_' + model_name + '.pkl'))
        policy.policy_net.load_state_dict(
            torch.load('./model_pkl/Policy_model_' + model_name + '.pkl'))
        value.load_state_dict(
            torch.load('./model_pkl/Value_model_' + model_name + '.pkl'))
        policy.policy_net_action_std = torch.load(
            './model_pkl/Policy_net_action_std_model_' + model_name + '.pkl')
        policy.transition_net_state_std = torch.load(
            './model_pkl/Transition_net_state_std_model_' + model_name +
            '.pkl')
    print('############# start training ##############')
    # update discriminator
    num = 0
    for ep in tqdm(range(args.training_epochs)):
        # collect data from environment for ppo update
        policy.train()
        value.train()
        discriminator.train()
        start_time = time.time()
        memory, n_trajs = policy.collect_samples(
            batch_size=args.sample_batch_size)
        # print('sample_data_time:{}'.format(time.time()-start_time))
        batch = memory.sample()
        # Flatten every per-trajectory field to (n_trajs * traj_len, -1) and
        # detach so the rollout graph is not kept alive.
        onehot_state = torch.cat(batch.onehot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        multihot_state = torch.cat(batch.multihot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        continuous_state = torch.cat(batch.continuous_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        onehot_action = torch.cat(batch.onehot_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        multihot_action = torch.cat(batch.multihot_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        continuous_action = torch.cat(batch.continuous_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        next_onehot_state = torch.cat(batch.next_onehot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        next_multihot_state = torch.cat(batch.next_multihot_state,
                                        dim=1).reshape(
                                            n_trajs *
                                            args.sample_traj_length,
                                            -1).detach()
        next_continuous_state = torch.cat(
            batch.next_continuous_state,
            dim=1).reshape(n_trajs * args.sample_traj_length, -1).detach()
        old_log_prob = torch.cat(batch.old_log_prob, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        mask = torch.cat(batch.mask, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        gen_state = torch.cat((onehot_state, multihot_state, continuous_state),
                              dim=-1)
        gen_action = torch.cat(
            (onehot_action, multihot_action, continuous_action), dim=-1)
        if ep % 1 == 0:
            # if (d_slow_flag and ep % 50 == 0) or (not d_slow_flag and ep % 1 == 0):
            d_loss = torch.empty(0, device=device)
            p_loss = torch.empty(0, device=device)
            v_loss = torch.empty(0, device=device)
            gen_r = torch.empty(0, device=device)
            expert_r = torch.empty(0, device=device)
            # Discriminator update: generated pairs -> label 0, expert -> 1.
            # Gaussian noise is added to both sides as regularization.
            for expert_state_batch, expert_action_batch in data_loader:
                noise1 = torch.normal(0,
                                      args.noise_std,
                                      size=gen_state.shape,
                                      device=device)
                noise2 = torch.normal(0,
                                      args.noise_std,
                                      size=gen_action.shape,
                                      device=device)
                noise3 = torch.normal(0,
                                      args.noise_std,
                                      size=expert_state_batch.shape,
                                      device=device)
                noise4 = torch.normal(0,
                                      args.noise_std,
                                      size=expert_action_batch.shape,
                                      device=device)
                gen_r = discriminator(gen_state + noise1, gen_action + noise2)
                expert_r = discriminator(
                    expert_state_batch.to(device) + noise3,
                    expert_action_batch.to(device) + noise4)
                # gen_r = discriminator(gen_state, gen_action)
                # expert_r = discriminator(expert_state_batch.to(device), expert_action_batch.to(device))
                optimizer_discriminator.zero_grad()
                d_loss = discriminator_criterion(gen_r, torch.zeros(gen_r.shape, device=device)) + \
                    discriminator_criterion(expert_r,torch.ones(expert_r.shape, device=device))
                variance = 0.5 * torch.var(gen_r.to(device)) + 0.5 * torch.var(
                    expert_r.to(device))
                # total_d_loss (variance-regularized) is computed but only
                # logged; the plain BCE d_loss is what gets backpropagated.
                total_d_loss = d_loss - 10 * variance
                d_loss.backward()
                # total_d_loss.backward()
                optimizer_discriminator.step()
            if write_scalar:
                writer.add_scalar('d_loss', d_loss, ep)
                writer.add_scalar('total_d_loss', total_d_loss, ep)
                writer.add_scalar('variance', 10 * variance, ep)
        if ep % 1 == 0:
            # update PPO
            noise1 = torch.normal(0,
                                  args.noise_std,
                                  size=gen_state.shape,
                                  device=device)
            noise2 = torch.normal(0,
                                  args.noise_std,
                                  size=gen_action.shape,
                                  device=device)
            # Discriminator score on (noised) generated pairs is the reward.
            gen_r = discriminator(gen_state + noise1, gen_action + noise2)
            #if gen_r.mean().item() < 0.1:
            #    d_stop = True
            #if d_stop and gen_r.mean()
            optimize_iter_num = int(
                math.ceil(onehot_state.shape[0] / args.ppo_mini_batch_size))
            # gen_r = -(1 - gen_r + 1e-10).log()
            for ppo_ep in range(args.ppo_optim_epoch):
                for i in range(optimize_iter_num):
                    num += 1
                    index = slice(
                        i * args.ppo_mini_batch_size,
                        min((i + 1) * args.ppo_mini_batch_size,
                            onehot_state.shape[0]))
                    onehot_state_batch, multihot_state_batch, continuous_state_batch, onehot_action_batch, multihot_action_batch, continuous_action_batch, \
                        old_log_prob_batch, mask_batch, next_onehot_state_batch, next_multihot_state_batch, next_continuous_state_batch, gen_r_batch = \
                        onehot_state[index], multihot_state[index], continuous_state[index], onehot_action[index], multihot_action[index], \
                        continuous_action[index], \
                        old_log_prob[index], mask[index], next_onehot_state[index], next_multihot_state[index], next_continuous_state[index], gen_r[
                            index]
                    v_loss, p_loss = ppo_step(
                        policy, value, optimizer_policy, optimizer_value,
                        onehot_state_batch, multihot_state_batch,
                        continuous_state_batch, onehot_action_batch,
                        multihot_action_batch, continuous_action_batch,
                        next_onehot_state_batch, next_multihot_state_batch,
                        next_continuous_state_batch, gen_r_batch,
                        old_log_prob_batch, mask_batch,
                        args.ppo_clip_epsilon)
            if write_scalar:
                writer.add_scalar('p_loss', p_loss, ep)
                writer.add_scalar('v_loss', v_loss, ep)
        # Evaluation / logging pass (no gradient updates below).
        policy.eval()
        value.eval()
        discriminator.eval()
        noise1 = torch.normal(0,
                              args.noise_std,
                              size=gen_state.shape,
                              device=device)
        noise2 = torch.normal(0,
                              args.noise_std,
                              size=gen_action.shape,
                              device=device)
        gen_r = discriminator(gen_state + noise1, gen_action + noise2)
        # NOTE(review): expert_state_batch / expert_action_batch and
        # noise3/noise4 are reused here from the LAST iteration of the
        # discriminator loop above — breaks with an empty data_loader and
        # only evaluates one expert mini-batch; confirm this is intended.
        expert_r = discriminator(
            expert_state_batch.to(device) + noise3,
            expert_action_batch.to(device) + noise4)
        gen_r_noise = gen_r.mean().item()
        expert_r_noise = expert_r.mean().item()
        gen_r = discriminator(gen_state, gen_action)
        expert_r = discriminator(expert_state_batch.to(device),
                                 expert_action_batch.to(device))
        if write_scalar:
            writer.add_scalar('gen_r', gen_r.mean(), ep)
            writer.add_scalar('expert_r', expert_r.mean(), ep)
            writer.add_scalar('gen_r_noise', gen_r_noise, ep)
            writer.add_scalar('expert_r_noise', expert_r_noise, ep)
        print('#' * 5 + 'training episode:{}'.format(ep) + '#' * 5)
        print('gen_r_noise', gen_r_noise)
        print('expert_r_noise', expert_r_noise)
        print('gen_r:', gen_r.mean().item())
        print('expert_r:', expert_r.mean().item())
        print('d_loss', d_loss.item())
        # save models
        if model_name is not None:
            torch.save(
                discriminator.state_dict(),
                './model_pkl/Discriminator_model_' + model_name + '.pkl')
            torch.save(policy.transition_net.state_dict(),
                       './model_pkl/Transition_model_' + model_name + '.pkl')
            torch.save(policy.policy_net.state_dict(),
                       './model_pkl/Policy_model_' + model_name + '.pkl')
            torch.save(
                policy.policy_net_action_std,
                './model_pkl/Policy_net_action_std_model_' + model_name +
                '.pkl')
            torch.save(
                policy.transition_net_state_std,
                './model_pkl/Transition_net_state_std_model_' + model_name +
                '.pkl')
            torch.save(value.state_dict(),
                       './model_pkl/Value_model_' + model_name + '.pkl')
        memory.clear_memory()
def main(data_dir):
    """Train PRNet (ResFCN256) with an adversarial discriminator on 300W-LP.

    Args:
        data_dir: root directory of the 300W-LP dataset.

    Side effects: writes TensorBoard logs, sample images and checkpoints to
    the paths configured in the module-level ``FLAGS`` dict.
    """
    # 0) Tensoboard Writer.
    writer = SummaryWriter(FLAGS['summary_path'])
    origin_img, uv_map_gt, uv_map_predicted = None, None, None
    if not os.path.exists(FLAGS['images']):
        os.mkdir(FLAGS['images'])

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToResize((416, 416)),
                              ToNormalize(FLAGS["normalize_mean"],
                                          FLAGS["normalize_std"])
                          ]))
    wlp300_dataloader = DataLoader(dataset=wlp300,
                                   batch_size=FLAGS['batch_size'],
                                   shuffle=True,
                                   num_workers=1)

    # 2) Intermediate Processing (normalization for the fixed test images).
    transform_img = transforms.Compose([
        transforms.Normalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
    ])

    # 3) Create PRNet model and its discriminator.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    model = ResFCN256(resolution_input=416,
                      resolution_output=416,
                      channel=3,
                      size=16)
    discriminator = Discriminator()

    # Load the pre-trained weight (resumes the epoch counter as well).
    if FLAGS['resume'] != "" and os.path.exists(
            os.path.join(FLAGS['pretrained'], FLAGS['resume'])):
        state = torch.load(os.path.join(FLAGS['pretrained'], FLAGS['resume']))
        model.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO(
            "Pre-trained weight cannot load successfully, train from scratch!")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to(FLAGS["device"])
    discriminator.to(FLAGS["device"])

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=FLAGS["lr"],
                                 betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=FLAGS["lr"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    stat_loss = SSIM(mask_path=FLAGS["mask_path"], gauss=FLAGS["gauss_kernel"])
    loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])
    bce_loss = torch.nn.BCEWithLogitsLoss()
    bce_loss.to(FLAGS["device"])  # Loss function for adversarial training.

    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        loss_list_G, stat_list = [], []
        loss_list_D = []
        for i, sample in enumerate(bar):
            uv_map, origin = sample['uv_map'].to(
                FLAGS['device']), sample['origin'].to(FLAGS['device'])

            # Inference.
            optimizer.zero_grad()
            uv_map_result = model(origin)

            # Update D (generator output detached so only D gets gradients).
            optimizer_D.zero_grad()
            fake_detach = uv_map_result.detach()
            d_fake = discriminator(fake_detach)
            d_real = discriminator(uv_map)
            retain_graph = False
            if FLAGS['gan_type'] == 'GAN':
                # FIX: the original computed bce_loss(d_real, d_fake), i.e.
                # used raw fake logits as the BCE *target* (targets must be
                # in [0, 1]). Standard non-saturating GAN discriminator loss:
                # real -> 1, fake -> 0.
                loss_d = bce_loss(d_real, torch.ones_like(d_real)) + \
                    bce_loss(d_fake, torch.zeros_like(d_fake))
            elif FLAGS['gan_type'].find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
                if FLAGS['gan_type'].find('GP') >= 0:
                    # WGAN-GP: penalize gradient norm on random interpolates.
                    epsilon = torch.rand(fake_detach.shape[0]).view(
                        -1, 1, 1, 1)
                    epsilon = epsilon.to(fake_detach.device)
                    hat = fake_detach.mul(1 - epsilon) + uv_map.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = discriminator(hat)
                    gradients = torch.autograd.grad(outputs=d_hat.sum(),
                                                    inputs=hat,
                                                    retain_graph=True,
                                                    create_graph=True,
                                                    only_inputs=True)[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty
            # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
            elif FLAGS['gan_type'] == 'RGAN':
                better_real = d_real - d_fake.mean(dim=0, keepdim=True)
                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
                # FIX: RaGAN discriminator loss uses fixed 1/0 targets for the
                # relativistic logits, not one logit tensor as the other's
                # target (original: bce_loss(better_real, better_fake)).
                loss_d = bce_loss(better_real, torch.ones_like(better_real)) + \
                    bce_loss(better_fake, torch.zeros_like(better_fake))
                # d_real's graph is reused in the generator step below.
                retain_graph = True

            if discriminator.training:
                loss_list_D.append(loss_d.item())
                loss_d.backward(retain_graph=retain_graph)
                optimizer_D.step()

            # Vanilla WGAN weight clipping (skipped for plain GAN / RGAN).
            if 'WGAN' in FLAGS['gan_type']:
                for p in discriminator.parameters():
                    p.data.clamp_(-1, 1)

            # Update G.
            d_fake_bp = discriminator(
                uv_map_result)  # for backpropagation, use fake as it is
            if FLAGS['gan_type'] == 'GAN':
                label_real = torch.ones_like(d_fake_bp)
                loss_g = bce_loss(d_fake_bp, label_real)
            elif FLAGS['gan_type'].find('WGAN') >= 0:
                loss_g = -d_fake_bp.mean()
            elif FLAGS['gan_type'] == 'RGAN':
                better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
                better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
                # FIX: RaGAN generator loss with swapped targets
                # (real -> 0, fake -> 1); original passed logits as targets.
                loss_g = bce_loss(better_real, torch.zeros_like(better_real)) + \
                    bce_loss(better_fake, torch.ones_like(better_fake))
            loss_g.backward()
            loss_list_G.append(loss_g.item())
            optimizer.step()

            stat_logit = stat_loss(uv_map_result, uv_map)
            stat_list.append(stat_logit.item())
            #bar.set_description(" {} [Loss(Paper)] {} [Loss(D)] {} [SSIM({})] {}".format(ep, loss_list_G[-1], loss_list_D[-1],FLAGS["gauss_kernel"], stat_list[-1]))

            # Record Training information in Tensorboard.
            """ if origin_img is None and uv_map_gt is None: origin_img, uv_map_gt = origin, uv_map uv_map_predicted = uv_map_result writer.add_scalar("Original Loss", loss_list_G[-1], FLAGS["summary_step"]) writer.add_scalar("D Loss", loss_list_D[-1], FLAGS["summary_step"]) writer.add_scalar("SSIM Loss", stat_list[-1], FLAGS["summary_step"]) grid_1, grid_2, grid_3 = make_grid(origin_img, normalize=True), make_grid(uv_map_gt), make_grid(uv_map_predicted) writer.add_image('original', grid_1, FLAGS["summary_step"]) writer.add_image('gt_uv_map', grid_2, FLAGS["summary_step"]) writer.add_image('predicted_uv_map', grid_3, FLAGS["summary_step"]) writer.add_graph(model, uv_map) """

        # Periodically run a fixed test image and checkpoint the model.
        if ep % FLAGS["save_interval"] == 0:
            with torch.no_grad():
                print(" {} [Loss(Paper)] {} [Loss(D)] {} [SSIM({})] {}".format(
                    ep, loss_list_G[-1], loss_list_D[-1],
                    FLAGS["gauss_kernel"], stat_list[-1]))
                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = np.load("./test_data/test_obama.npy")
                origin, gt_uv_map = test_data_preprocess(
                    origin), test_data_preprocess(gt_uv_map)
                origin, gt_uv_map = transform_img(origin), transform_img(
                    gt_uv_map)
                origin_in = origin.unsqueeze_(0).cuda()
                pred_uv_map = model(origin_in).detach().cpu()
                save_image(
                    [origin.cpu(),
                     gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                    os.path.join(FLAGS['images'],
                                 str(ep) + '.png'),
                    nrow=1,
                    normalize=True)

            # Save model
            print("Save model")
            state = {
                'prnet': model.state_dict(),
                'Loss': loss_list_G,
                'start_epoch': ep,
                'Loss_D': loss_list_D,
            }
            torch.save(state, os.path.join(FLAGS['images'],
                                           '{}.pth'.format(ep)))
        scheduler.step()
    writer.close()
class Trainer(nn.Module):
    """Face-swap GAN trainer (AEI-Net style): generator vs. multi-scale
    discriminator, with ArcFace identity losses, amp (O1) mixed precision,
    TensorBoard logging and rolling checkpoints under
    ``checkpoints/<model_dir>/``.
    """

    def __init__(self, model_dir, g_optimizer, d_optimizer, lr, warmup,
                 max_iters):
        super().__init__()
        self.model_dir = model_dir
        if not os.path.exists(f'checkpoints/{model_dir}'):
            os.makedirs(f'checkpoints/{model_dir}')
        self.logs_dir = f'checkpoints/{model_dir}/logs'
        if not os.path.exists(self.logs_dir):
            os.makedirs(self.logs_dir)
        self.writer = SummaryWriter(self.logs_dir)

        # Frozen identity encoders: ArcFace drives the loss, MobileFaceNet
        # is used only for held-out similarity evaluation.
        self.arcface = ArcFaceNet(50, 0.6, 'ir_se').cuda()
        self.arcface.eval()
        self.arcface.load_state_dict(torch.load(
            'checkpoints/model_ir_se50.pth', map_location='cuda'),
                                     strict=False)
        self.mobiface = MobileFaceNet(512).cuda()
        self.mobiface.eval()
        self.mobiface.load_state_dict(torch.load(
            'checkpoints/mobilefacenet.pth', map_location='cuda'),
                                      strict=False)

        self.generator = Generator().cuda()
        self.discriminator = Discriminator().cuda()

        # Loss weights for the generator objective (see g_loss).
        self.adversarial_weight = 1
        self.src_id_weight = 5
        self.tgt_id_weight = 1
        self.attributes_weight = 10
        self.reconstruction_weight = 10

        self.lr = lr
        self.warmup = warmup
        self.g_optimizer = g_optimizer(self.generator.parameters(),
                                       lr=lr,
                                       betas=(0, 0.999))
        self.d_optimizer = d_optimizer(self.discriminator.parameters(),
                                       lr=lr,
                                       betas=(0, 0.999))
        self.generator, self.g_optimizer = amp.initialize(self.generator,
                                                          self.g_optimizer,
                                                          opt_level="O1")
        self.discriminator, self.d_optimizer = amp.initialize(
            self.discriminator, self.d_optimizer, opt_level="O1")

        # Iteration counter stored as a buffer-like Parameter so it is
        # persisted with state_dict and moved with .cuda().
        self._iter = nn.Parameter(torch.tensor(1), requires_grad=False)
        self.max_iters = max_iters
        if torch.cuda.is_available():
            self.cuda()

    @property
    def iter(self):
        """Current training iteration as a plain int."""
        return self._iter.item()

    @property
    def device(self):
        """Device of the module's parameters."""
        return next(self.parameters()).device

    def adapt(self, args):
        """Move every tensor in ``args`` to this module's device."""
        device = self.device
        return [arg.to(device) for arg in args]

    def train_loop(self, dataloaders, eval_every, generate_every, save_every):
        """Alternate G/D steps over ``dataloaders['train']``, periodically
        evaluating, generating previews and checkpointing."""
        for batch in tqdm(dataloaders['train']):
            torch.Tensor.add_(self._iter, 1)

            # generator step
            # if self.iter % 2 == 0:
            # self.adjust_lr(self.g_optimizer)
            g_losses = self.g_step(self.adapt(batch))
            g_stats = self.get_opt_stats(self.g_optimizer, type='generator')
            self.write_logs(losses=g_losses, stats=g_stats, type='generator')

            # #discriminator step
            # if self.iter % 2 == 1:
            # self.adjust_lr(self.d_optimizer)
            d_losses = self.d_step(self.adapt(batch))
            d_stats = self.get_opt_stats(self.d_optimizer,
                                         type='discriminator')
            self.write_logs(losses=d_losses,
                            stats=d_stats,
                            type='discriminator')

            if self.iter % eval_every == 0:
                discriminator_acc = self.evaluate_discriminator_accuracy(
                    dataloaders['val'])
                identification_acc = self.evaluate_identification_similarity(
                    dataloaders['val'])
                metrics = {**discriminator_acc, **identification_acc}
                self.write_logs(metrics=metrics)

            if self.iter % generate_every == 0:
                self.generate(*self.adapt(batch))

            if self.iter % save_every == 0:
                self.save_discriminator()
                self.save_generator()

    def g_step(self, batch):
        """One generator update; returns the scalar loss components."""
        self.generator.train()
        self.g_optimizer.zero_grad()
        L_adv, L_src_id, L_tgt_id, L_attr, L_rec, L_generator = self.g_loss(
            *batch)
        with amp.scale_loss(L_generator, self.g_optimizer) as scaled_loss:
            scaled_loss.backward()
        self.g_optimizer.step()
        losses = {
            'adv': L_adv.item(),
            'src_id': L_src_id.item(),
            'tgt_id': L_tgt_id.item(),
            'attributes': L_attr.item(),
            'reconstruction': L_rec.item(),
            'total_loss': L_generator.item()
        }
        return losses

    def d_step(self, batch):
        """One discriminator update (hinge loss); returns scalar losses."""
        self.discriminator.train()
        self.d_optimizer.zero_grad()
        L_fake, L_real, L_discriminator = self.d_loss(*batch)
        with amp.scale_loss(L_discriminator, self.d_optimizer) as scaled_loss:
            scaled_loss.backward()
        self.d_optimizer.step()
        losses = {
            'hinge_fake': L_fake.item(),
            'hinge_real': L_real.item(),
            'total_loss': L_discriminator.item()
        }
        return losses

    def g_loss(self, Xs, Xt, same_person):
        """Full generator objective.

        Xs/Xt are source/target image batches; ``same_person`` masks pairs
        where source == target so reconstruction applies only to them.
        The 19:237 crop + 112x112 resize matches ArcFace's expected input
        (assumes 256x256 inputs — TODO confirm upstream).
        """
        with torch.no_grad():
            src_embed = self.arcface(
                F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))
            tgt_embed = self.arcface(
                F.interpolate(Xt[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))
        Y_hat, Xt_attr = self.generator(Xt, src_embed, return_attributes=True)

        # Adversarial: sum hinge losses over discriminator scales.
        Di = self.discriminator(Y_hat)
        L_adv = 0
        for di in Di:
            L_adv += hinge_loss(di[0], True)

        # Identity: pull swap toward source identity, track target identity.
        fake_embed = self.arcface(
            F.interpolate(Y_hat[:, :, 19:237, 19:237], [112, 112],
                          mode='bilinear',
                          align_corners=True))
        L_src_id = (
            1 - torch.cosine_similarity(src_embed, fake_embed, dim=1)).mean()
        L_tgt_id = (
            1 - torch.cosine_similarity(tgt_embed, fake_embed, dim=1)).mean()

        # Attribute preservation: L2 between target and output attributes.
        batch_size = Xs.shape[0]
        Y_hat_attr = self.generator.get_attr(Y_hat)
        L_attr = 0
        for i in range(len(Xt_attr)):
            L_attr += torch.mean(torch.pow(Xt_attr[i] -
                                           Y_hat_attr[i], 2).reshape(
                                               batch_size, -1),
                                 dim=1).mean()
        L_attr /= 2.0

        # Reconstruction only on same-identity pairs.
        L_rec = torch.sum(
            0.5 * torch.mean(torch.pow(Y_hat - Xt, 2).reshape(batch_size, -1),
                             dim=1) * same_person) / (same_person.sum() + 1e-6)

        L_generator = (self.adversarial_weight * L_adv) + (
            self.src_id_weight * L_src_id) + (
                self.tgt_id_weight * L_tgt_id) + (
                    self.attributes_weight * L_attr) + (
                        self.reconstruction_weight * L_rec)
        return L_adv, L_src_id, L_tgt_id, L_attr, L_rec, L_generator

    def d_loss(self, Xs, Xt, same_person):
        """Hinge discriminator loss: generated faces fake, Xs real."""
        with torch.no_grad():
            src_embed = self.arcface(
                F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))
        Y_hat = self.generator(Xt, src_embed, return_attributes=False)
        fake_D = self.discriminator(Y_hat.detach())
        L_fake = 0
        for di in fake_D:
            L_fake += hinge_loss(di[0], False)
        real_D = self.discriminator(Xs)
        L_real = 0
        for di in real_D:
            L_real += hinge_loss(di[0], True)
        L_discriminator = 0.5 * (L_real + L_fake)
        return L_fake, L_real, L_discriminator

    def evaluate_discriminator_accuracy(self, val_dataloader):
        """Fraction of val samples the discriminator classifies correctly
        (sign of the hinge logit), averaged over scales and batches."""
        real_acc = 0
        fake_acc = 0
        self.generator.eval()
        self.discriminator.eval()
        for batch in tqdm(val_dataloader):
            Xs, Xt, _ = self.adapt(batch)
            with torch.no_grad():
                embed = self.arcface(
                    F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                Y_hat = self.generator(Xt, embed, return_attributes=False)
                fake_D = self.discriminator(Y_hat)
                real_D = self.discriminator(Xs)
            fake_multiscale_acc = 0
            for di in fake_D:
                fake_multiscale_acc += torch.mean((di[0] < 0).float())
            fake_acc += fake_multiscale_acc / len(fake_D)
            real_multiscale_acc = 0
            for di in real_D:
                real_multiscale_acc += torch.mean((di[0] > 0).float())
            real_acc += real_multiscale_acc / len(real_D)
        self.generator.train()
        self.discriminator.train()
        metrics = {
            'fake_acc': 100 * (fake_acc / len(val_dataloader)).item(),
            'real_acc': 100 * (real_acc / len(val_dataloader)).item()
        }
        return metrics

    def evaluate_identification_similarity(self, val_dataloader):
        """Cosine similarity (via MobileFaceNet) of the swap to source and
        target identities, averaged over the validation set."""
        src_id_sim = 0
        tgt_id_sim = 0
        self.generator.eval()
        for batch in tqdm(val_dataloader):
            Xs, Xt, _ = self.adapt(batch)
            with torch.no_grad():
                src_embed = self.arcface(
                    F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                Y_hat = self.generator(Xt, src_embed, return_attributes=False)
                src_embed = self.mobiface(
                    F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                tgt_embed = self.mobiface(
                    F.interpolate(Xt[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                fake_embed = self.mobiface(
                    F.interpolate(Y_hat[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
            src_id_sim += (torch.cosine_similarity(src_embed, fake_embed,
                                                   dim=1)).float().mean()
            tgt_id_sim += (torch.cosine_similarity(tgt_embed, fake_embed,
                                                   dim=1)).float().mean()
        self.generator.train()
        metrics = {
            'src_similarity': 100 * (src_id_sim / len(val_dataloader)).item(),
            'tgt_similarity': 100 * (tgt_id_sim / len(val_dataloader)).item()
        }
        return metrics

    def generate(self, Xs, Xt, same_person):
        """Render a source/target/swap comparison strip to results/."""

        def get_grid_image(X):
            # First 8 samples, de-normalized from [-1,1] to [0,255].
            X = X[:8]
            X = torchvision.utils.make_grid(X.detach().cpu(), nrow=X.shape[0])
            X = (X * 0.5 + 0.5) * 255
            return X

        def make_image(Xs, Xt, Y_hat):
            Xs = get_grid_image(Xs)
            Xt = get_grid_image(Xt)
            Y_hat = get_grid_image(Y_hat)
            return torch.cat((Xs, Xt, Y_hat), dim=1).numpy()

        with torch.no_grad():
            embed = self.arcface(
                F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))
            self.generator.eval()
            Y_hat = self.generator(Xt, embed, return_attributes=False)
            self.generator.train()
            image = make_image(Xs, Xt, Y_hat)
        if not os.path.exists(f'results/{self.model_dir}'):
            os.makedirs(f'results/{self.model_dir}')
        cv2.imwrite(f'results/{self.model_dir}/{self.iter}.jpg',
                    image.transpose([1, 2, 0]))

    def get_opt_stats(self, optimizer, type=''):
        """Loggable optimizer stats (currently just the learning rate)."""
        stats = {f'{type}_lr': optimizer.param_groups[0]['lr']}
        return stats

    def adjust_lr(self, optimizer):
        """Linear warmup for ``warmup`` iters, then cosine decay to 0."""
        if self.iter <= self.warmup:
            lr = self.lr * self.iter / self.warmup
        else:
            lr = self.lr * (1 + cos(pi * (self.iter - self.warmup) /
                                    (self.max_iters - self.warmup))) / 2
        for group in optimizer.param_groups:
            group['lr'] = lr
        return lr

    def write_logs(self, losses=None, metrics=None, stats=None, type='loss'):
        """Write any of losses/metrics/stats dicts to TensorBoard."""
        if losses:
            for name, value in losses.items():
                self.writer.add_scalar(f'{type}/{name}', value, self.iter)
        if metrics:
            for name, value in metrics.items():
                self.writer.add_scalar(f'metric/{name}', value, self.iter)
        if stats:
            for name, value in stats.items():
                self.writer.add_scalar(f'stats/{name}', value, self.iter)

    def save_generator(self, max_checkpoints=100):
        """Save generator weights, pruning the oldest beyond the cap.

        FIX: the original globbed ``{self.model_dir}/*.pt`` while files are
        written under ``checkpoints/{self.model_dir}/``, so pruning never
        matched anything; and it removed ``checkpoints[-1]`` — an arbitrary
        glob-ordered entry — instead of the oldest file.
        """
        checkpoints = glob.glob(f'checkpoints/{self.model_dir}/generator_*.pt')
        if len(checkpoints) > max_checkpoints:
            os.remove(min(checkpoints, key=os.path.getctime))
        with open(f'checkpoints/{self.model_dir}/generator_{self.iter}.pt',
                  'wb') as f:
            torch.save(self.generator.state_dict(), f)

    def save_discriminator(self, max_checkpoints=100):
        """Save discriminator weights, pruning the oldest beyond the cap.

        Same glob/oldest fix as ``save_generator``.
        """
        checkpoints = glob.glob(
            f'checkpoints/{self.model_dir}/discriminator_*.pt')
        if len(checkpoints) > max_checkpoints:
            os.remove(min(checkpoints, key=os.path.getctime))
        with open(f'checkpoints/{self.model_dir}/discriminator_{self.iter}.pt',
                  'wb') as f:
            torch.save(self.discriminator.state_dict(), f)

    def load_discriminator(self, path, load_last=True):
        """Load discriminator weights; with ``load_last`` pick the most
        recently created ``discriminator*.pt`` under ``path``."""
        if load_last:
            try:
                checkpoints = glob.glob(f'{path}/discriminator*.pt')
                path = max(checkpoints, key=os.path.getctime)
            except (ValueError):
                print(f'Directory is empty: {path}')
        try:
            self.discriminator.load_state_dict(torch.load(path))
            self.cuda()
        except (FileNotFoundError):
            print(f'No such file: {path}')

    def load_generator(self, path, load_last=True):
        """Load generator weights and restore the iteration counter from
        the checkpoint filename."""
        if load_last:
            try:
                checkpoints = glob.glob(f'{path}/generator*.pt')
                path = max(checkpoints, key=os.path.getctime)
            except (ValueError):
                print(f'Directory is empty: {path}')
        try:
            self.generator.load_state_dict(torch.load(path))
            # FIX: take digits from the basename only; the original scanned
            # the whole path, so a digit anywhere in the directory name
            # corrupted the recovered iteration number.
            iter_str = ''.join(
                filter(lambda x: x.isdigit(), os.path.basename(path)))
            self._iter = nn.Parameter(torch.tensor(int(iter_str)),
                                      requires_grad=False)
            self.cuda()
        except (FileNotFoundError):
            print(f'No such file: {path}')
def run(args):
    """Train a neural vocoder (MelGAN / HiFi-GAN / Multi-band HiFi-GAN /
    Basis-MelGAN) with a GAN discriminator.

    Builds the generator named by ``args.model_name`` from the YAML config,
    sets up optimizers (optionally apex amp O1 mixed precision), resumes
    from ``args.checkpoint_path`` when loadable, then runs the epoch loop,
    delegating each optimization step to the module-level ``trainer``
    function and periodically computing a validation STFT loss.
    """
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Define model
    logger.info(f"Loading Model of {args.model_name}...")
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    # NOTE(review): "lamda_stft" is the literal config key used by this
    # project's configs — do not "fix" the spelling without migrating them.
    hp.lambda_stft = config["lamda_stft"]
    hp.use_feature_map_loss = config["use_feature_map_loss"]
    if args.model_name == "melgan":
        model = MelGANGenerator(
            in_channels=config["in_channels"],
            out_channels=config["out_channels"],
            kernel_size=config["kernel_size"],
            channels=config["channels"],
            upsample_scales=config["upsample_scales"],
            stack_kernel_size=config["stack_kernel_size"],
            stacks=config["stacks"],
            use_weight_norm=config["use_weight_norm"],
            use_causal_conv=config["use_causal_conv"]).to(device)
    elif args.model_name == "hifigan":
        model = HiFiGANGenerator(
            resblock_kernel_sizes=config["resblock_kernel_sizes"],
            upsample_rates=config["upsample_rates"],
            upsample_initial_channel=config["upsample_initial_channel"],
            resblock_type=config["resblock_type"],
            upsample_kernel_sizes=config["upsample_kernel_sizes"],
            resblock_dilation_sizes=config["resblock_dilation_sizes"],
            transposedconv=config["transposedconv"],
            bias=config["bias"]).to(device)
    elif args.model_name == "multiband-hifigan":
        model = MultiBandHiFiGANGenerator(
            resblock_kernel_sizes=config["resblock_kernel_sizes"],
            upsample_rates=config["upsample_rates"],
            upsample_initial_channel=config["upsample_initial_channel"],
            resblock_type=config["resblock_type"],
            upsample_kernel_sizes=config["upsample_kernel_sizes"],
            resblock_dilation_sizes=config["resblock_dilation_sizes"],
            transposedconv=config["transposedconv"],
            bias=config["bias"]).to(device)
    elif args.model_name == "basis-melgan":
        # Basis-MelGAN needs a precomputed basis-signal weight matrix.
        basis_signal_weight = np.load(
            os.path.join("Basis-MelGAN-dataset", "basis_signal_weight.npy"))
        basis_signal_weight = torch.from_numpy(basis_signal_weight)
        model = BasisMelGANGenerator(
            basis_signal_weight=basis_signal_weight,
            L=config["L"],
            in_channels=config["in_channels"],
            out_channels=config["out_channels"],
            kernel_size=config["kernel_size"],
            channels=config["channels"],
            upsample_scales=config["upsample_scales"],
            stack_kernel_size=config["stack_kernel_size"],
            stacks=config["stacks"],
            use_weight_norm=config["use_weight_norm"],
            use_causal_conv=config["use_causal_conv"],
            transposedconv=config["transposedconv"]).to(device)
    else:
        raise Exception("no model find!")
    pqmf = None
    if config["multiband"] == True:
        logger.info("Define PQMF")
        pqmf = PQMF().to(device)
    logger.info(f"model is {str(model)}")
    discriminator = Discriminator().to(device)
    logger.info("Model Has Been Defined")
    num_param = get_param_num(model)
    logger.info(f'Number of TTS Parameters: {num_param}')
    # Optimizer and loss
    basis_signal_optimizer = None
    if not args.mixprecision:
        if args.model_name == "basis-melgan":
            # Basis-MelGAN: main optimizer covers only the MelGAN trunk;
            # the basis-signal layer gets its own optimizer.
            optimizer = Adam(model.melgan.parameters(),
                             lr=args.learning_rate,
                             eps=1.0e-6,
                             weight_decay=0.0)
            # freeze basis signal layer
            basis_signal_optimizer = Adam(model.basis_signal.parameters())
        else:
            optimizer = Adam(model.parameters(),
                             lr=args.learning_rate,
                             eps=1.0e-6,
                             weight_decay=0.0)
        discriminator_optimizer = Adam(discriminator.parameters(),
                                       lr=args.learning_rate_discriminator,
                                       eps=1.0e-6,
                                       weight_decay=0.0)
    else:
        if args.model_name == "basis-melgan":
            raise Exception("basis melgan don't support amp!")
        # apex amp O1 mixed precision path.
        optimizer = apex.optimizers.FusedAdam(model.parameters(),
                                              lr=args.learning_rate)
        discriminator_optimizer = apex.optimizers.FusedAdam(
            discriminator.parameters(), lr=args.learning_rate_discriminator)
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          keep_batchnorm_fp32=None)
        discriminator, discriminator_optimizer = amp.initialize(
            discriminator, discriminator_optimizer, opt_level="O1")
        logger.info("Start mix precision training...")
    if args.use_scheduler:
        scheduler = CosineAnnealingLR(optimizer,
                                      T_max=2500,
                                      eta_min=args.learning_rate / 10.)
        discriminator_scheduler = CosineAnnealingLR(
            discriminator_optimizer,
            T_max=2500,
            eta_min=args.learning_rate_discriminator / 10.)
    else:
        scheduler = None
        discriminator_scheduler = None
    vocoder_loss = Loss().to(device)
    logger.info("Defined Optimizer and Loss Function.")
    # Load checkpoint if exists
    os.makedirs(hp.checkpoint_path, exist_ok=True)
    # Timestamped run directory (': ' and '.' are not filename-safe).
    current_checkpoint_path = str(datetime.now()).replace(" ", "-").replace(
        ":", "-").replace(".", "-")
    current_checkpoint_path = os.path.join(hp.checkpoint_path,
                                           current_checkpoint_path)
    try:
        checkpoint = torch.load(os.path.join(args.checkpoint_path),
                                map_location=torch.device(device))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if 'discriminator' in checkpoint:
            logger.info("loading discriminator")
            discriminator.load_state_dict(checkpoint['discriminator'])
            discriminator_optimizer.load_state_dict(
                checkpoint['discriminator_optimizer'])
        os.makedirs(current_checkpoint_path, exist_ok=True)
        if args.mixprecision:
            amp.load_state_dict(checkpoint['amp'])
        logger.info("\n---Model Restored at Step %d---\n" % args.restore_step)
    except:
        # NOTE(review): bare except deliberately treats any resume failure
        # (missing file, key mismatch, …) as "start fresh", but it also hides
        # real errors — consider narrowing to (FileNotFoundError, KeyError,
        # RuntimeError) and logging the exception.
        logger.info("\n---Start New Training---\n")
        os.makedirs(current_checkpoint_path, exist_ok=True)
    # Init logger
    os.makedirs(hp.logger_path, exist_ok=True)
    current_logger_path = str(datetime.now()).replace(" ", "-").replace(
        ":", "-").replace(".", "-")
    writer = SummaryWriter(
        os.path.join(hp.tensorboard_path, current_logger_path))
    current_logger_path = os.path.join(hp.logger_path, current_logger_path)
    os.makedirs(current_logger_path, exist_ok=True)
    # Get buffer (non-basis models pre-load features into memory).
    if args.model_name != "basis-melgan":
        logger.info("Load data to buffer")
        buffer = load_data_to_buffer(args.audio_index_path,
                                     args.mel_index_path,
                                     logger,
                                     feature_savepath="features_train.bin")
        logger.info("Load valid data to buffer")
        valid_buffer = load_data_to_buffer(
            args.audio_index_valid_path,
            args.mel_index_valid_path,
            logger,
            feature_savepath="features_valid.bin")
    # Get dataset
    if args.model_name == "basis-melgan":
        dataset = WeightDataset(args.audio_index_path, args.mel_index_path,
                                config["L"])
        valid_dataset = WeightDataset(args.audio_index_valid_path,
                                      args.mel_index_valid_path, config["L"])
    else:
        dataset = BufferDataset(buffer)
        valid_dataset = BufferDataset(valid_buffer)
    # Get Training Loader. Each loader item is a list of batch_expand_size
    # sub-batches (see the inner `for j, db in enumerate(batchs)` loop).
    training_loader = DataLoader(dataset,
                                 batch_size=hp.batch_expand_size *
                                 hp.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn_tensor,
                                 drop_last=True,
                                 num_workers=4,
                                 prefetch_factor=2,
                                 pin_memory=True)
    logger.info(f"Length of training loader is {len(training_loader)}")
    total_step = hp.epochs * len(training_loader) * hp.batch_expand_size
    # Define Some Information
    time_list = np.array([])
    Start = time.perf_counter()
    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        for i, batchs in enumerate(training_loader):
            # real batch start here
            for j, db in enumerate(batchs):
                # Global step, offset by any restored step count.
                current_step = i * hp.batch_expand_size + j + \
                    args.restore_step + epoch * len(
                        training_loader) * hp.batch_expand_size + 1
                # Get Data
                clock_1_s = time.perf_counter()
                mel = db["mel"].float().to(device)
                wav = db["wav"].float().to(device)
                # (batch, frames, bins) -> (batch, bins, frames).
                mel = mel.contiguous().transpose(1, 2)
                weight = None
                if "weight" in db:
                    weight = db["weight"].float().to(device)
                clock_1_e = time.perf_counter()
                time_used_1 = round(clock_1_e - clock_1_s, 5)
                # Training: one G/D optimization step is delegated to the
                # module-level `trainer` function.
                clock_2_s = time.perf_counter()
                time_list = trainer(
                    model,
                    discriminator,
                    optimizer,
                    discriminator_optimizer,
                    scheduler,
                    discriminator_scheduler,
                    vocoder_loss,
                    mel,
                    wav,
                    epoch,
                    current_step,
                    total_step,
                    time_list,
                    Start,
                    current_checkpoint_path,
                    current_logger_path,
                    writer,
                    weight=weight,
                    basis_signal_optimizer=basis_signal_optimizer,
                    pqmf=pqmf,
                    mixprecision=args.mixprecision)
                clock_2_e = time.perf_counter()
                time_used_2 = round(clock_2_e - clock_2_s, 5)
                if current_step % hp.valid_step == 0:
                    logger.info("Start valid...")
                    valid_loader = DataLoader(
                        valid_dataset,
                        batch_size=1,
                        shuffle=True,
                        collate_fn=collate_fn_tensor_valid,
                        num_workers=0)
                    valid_loss_all = 0.
                    for ii, valid_batch in enumerate(valid_loader):
                        valid_mel = valid_batch["mel"].float().to(device)
                        valid_mel = valid_mel.contiguous().transpose(1, 2)
                        valid_wav = valid_batch["wav"].float().to(device)
                        with torch.no_grad():
                            # Basis-MelGAN returns (audio, weight) tuples.
                            if args.model_name == "basis-melgan":
                                valid_est_source, _ = model(valid_mel)
                            else:
                                valid_est_source = model(valid_mel)
                            valid_stft_loss, _ = vocoder_loss(
                                valid_est_source, valid_wav, pqmf=pqmf)
                            valid_loss_all += valid_stft_loss.item()
                        if ii == hp.valid_num:
                            break
                    writer.add_scalar('valid_stft_loss',
                                      valid_loss_all / float(hp.valid_num),
                                      global_step=current_step)
    # NOTE(review): export_scalars_to_json exists on tensorboardX's
    # SummaryWriter, not torch.utils.tensorboard's — verify which one this
    # file imports.
    writer.export_scalars_to_json(os.path.join("all_scalars.json"))
    writer.close()
    return
def main(args=args):
    """Train a DCGAN on CelebA.

    Reads the official CelebA eval partition to split train/valid/test,
    optionally loads the 40 binary attributes per image, then runs the
    epoch loop (delegating each epoch to the module-level ``train``),
    stepping the LR schedulers and checkpointing after every epoch.

    Uses the module-level ``args`` namespace as the default argument.
    """
    dataset_base_path = path.join(args.base_path, "dataset", "celeba")
    image_base_path = path.join(dataset_base_path, "img_align_celeba")
    split_dataset_path = path.join(dataset_base_path, "Eval",
                                   "list_eval_partition.txt")
    with open(split_dataset_path, "r") as f:
        split_annotation = f.read().splitlines()
    # create the data name list for train,test and valid
    # (partition codes: 0 = train, 1 = valid, 2 = test)
    train_data_name_list = []
    test_data_name_list = []
    valid_data_name_list = []
    for item in split_annotation:
        item = item.split(" ")
        if item[1] == '0':
            train_data_name_list.append(item[0])
        elif item[1] == '1':
            valid_data_name_list.append(item[0])
        else:
            test_data_name_list.append(item[0])
    attribute_annotation_dict = None
    if args.need_label:
        # list_attr_celeba.txt: line 1 = count, line 2 = header, then
        # "<img_name> <40 space-separated +1/-1 values>" per image.
        attribute_annotation_path = path.join(dataset_base_path, "Anno",
                                              "list_attr_celeba.txt")
        with open(attribute_annotation_path, "r") as f:
            attribute_annotation = f.read().splitlines()
        attribute_annotation = attribute_annotation[2:]
        attribute_annotation_dict = {}
        for item in attribute_annotation:
            img_name, attribute = item.split(" ", 1)
            # FIX: parse with int() instead of eval() — the values are just
            # "1"/"-1" and eval() would execute arbitrary file content.
            attribute = tuple(
                [int(attr) for attr in attribute.split(" ") if attr != ""])
            assert len(attribute) == 40, \
                "the attribute of item {} is not equal to 40".format(img_name)
            attribute_annotation_dict[img_name] = attribute
    discriminator = Discriminator(num_channel=args.num_channel,
                                  num_feature=args.dnf,
                                  data_parallel=args.data_parallel).cuda()
    generator = Generator(latent_dim=args.latent_dim,
                          num_feature=args.gnf,
                          num_channel=args.num_channel,
                          data_parallel=args.data_parallel).cuda()
    # Interactive pause so the operator can confirm the dataset summary.
    input("Begin the {} time's training, the train dataset has {} images and the valid has {} images".format(
        args.train_time, len(train_data_name_list), len(valid_data_name_list)))
    d_optimizer = torch.optim.Adam(discriminator.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta1, 0.999))
    g_optimizer = torch.optim.Adam(generator.parameters(),
                                   lr=args.lr,
                                   betas=(args.beta1, 0.999))
    d_scheduler = ExponentialLR(d_optimizer, gamma=args.decay_lr)
    g_scheduler = ExponentialLR(g_optimizer, gamma=args.decay_lr)
    writer_log_dir = "{}/DCGAN/runs/train_time:{}".format(args.base_path,
                                                          args.train_time)
    # Here we implement the resume part
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if not args.not_resume_arg:
                # Restore the full args namespace from the checkpoint,
                # keeping only the recorded epoch as the new start epoch.
                args = checkpoint['args']
                args.start_epoch = checkpoint['epoch']
            discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
            generator.load_state_dict(checkpoint["generator_state_dict"])
            d_optimizer.load_state_dict(checkpoint['discriminator_optimizer'])
            g_optimizer.load_state_dict(checkpoint['generator_optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            raise FileNotFoundError("Checkpoint Resume File {} Not Found".format(args.resume))
    else:
        # Fresh run: offer to wipe an existing TensorBoard log directory.
        if os.path.exists(writer_log_dir):
            flag = input("DCGAN train_time:{} will be removed, input yes to continue:".format(
                args.train_time))
            if flag == "yes":
                shutil.rmtree(writer_log_dir, ignore_errors=True)
    writer = SummaryWriter(log_dir=writer_log_dir)
    # Here we just use the train dset in training
    train_dset = CelebADataset(base_path=image_base_path,
                               data_name_list=train_data_name_list,
                               image_size=args.image_size,
                               label_dict=attribute_annotation_dict)
    train_dloader = DataLoader(dataset=train_dset,
                               batch_size=args.batch_size,
                               shuffle=True,
                               num_workers=args.workers,
                               pin_memory=True)
    criterion = nn.BCELoss()
    for epoch in range(args.start_epoch, args.epochs):
        train(train_dloader, generator, discriminator, g_optimizer,
              d_optimizer, criterion, writer, epoch)
        # adjust lr
        d_scheduler.step()
        g_scheduler.step()
        # save parameters
        save_checkpoint({
            'epoch': epoch + 1,
            'args': args,
            "discriminator_state_dict": discriminator.state_dict(),
            "generator_state_dict": generator.state_dict(),
            'discriminator_optimizer': d_optimizer.state_dict(),
            'generator_optimizer': g_optimizer.state_dict()
        })
def main():
    """GAIL-style training: a PPO actor/critic against a discriminator
    that separates generated (state, action) pairs from expert data.

    NOTE(review): this chunk was recovered from collapsed formatting —
    the nesting of the logging/saving statements inside the epoch loop is
    reconstructed and should be checked against the original file.
    """
    # define actor/critic/discriminator net and optimizer
    policy = Policy(discrete_action_sections, discrete_state_sections)
    value = Value()
    discriminator = Discriminator()
    optimizer_policy = torch.optim.Adam(policy.parameters(),
                                        lr=args.policy_lr)
    optimizer_value = torch.optim.Adam(value.parameters(), lr=args.value_lr)
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=args.discrim_lr)
    discriminator_criterion = nn.BCELoss()
    writer = SummaryWriter()
    # load expert data
    dataset = ExpertDataSet(args.expert_activities_data_path,
                            args.expert_cost_data_path)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=args.expert_batch_size,
                                  shuffle=False,
                                  num_workers=1)
    # load models
    # discriminator.load_state_dict(torch.load('./model_pkl/Discriminator_model_3.pkl'))
    # policy.transition_net.load_state_dict(torch.load('./model_pkl/Transition_model_3.pkl'))
    # policy.policy_net.load_state_dict(torch.load('./model_pkl/Policy_model_3.pkl'))
    # value.load_state_dict(torch.load('./model_pkl/Value_model_3.pkl'))
    print('############# start training ##############')
    # update discriminator
    num = 0  # global PPO mini-batch counter (used as the TB step).
    for ep in tqdm(range(args.training_epochs)):
        # collect data from environment for ppo update
        start_time = time.time()
        memory = policy.collect_samples(args.ppo_buffer_size, size=10000)
        # print('sample_data_time:{}'.format(time.time()-start_time))
        batch = memory.sample()
        # Stack the rollout into flat tensors, detached from the policy
        # graph; squeeze(1) drops the per-sample batch dim of size 1.
        continuous_state = torch.stack(
            batch.continuous_state).squeeze(1).detach()
        discrete_action = torch.stack(
            batch.discrete_action).squeeze(1).detach()
        continuous_action = torch.stack(
            batch.continuous_action).squeeze(1).detach()
        next_discrete_state = torch.stack(
            batch.next_discrete_state).squeeze(1).detach()
        next_continuous_state = torch.stack(
            batch.next_continuous_state).squeeze(1).detach()
        old_log_prob = torch.stack(batch.old_log_prob).detach()
        mask = torch.stack(batch.mask).squeeze(1).detach()
        discrete_state = torch.stack(batch.discrete_state).squeeze(1).detach()
        # Placeholders so the epoch-level prints below are defined even if
        # the discriminator loop body never runs.
        d_loss = torch.empty(0, device=device)
        p_loss = torch.empty(0, device=device)
        v_loss = torch.empty(0, device=device)
        gen_r = torch.empty(0, device=device)
        expert_r = torch.empty(0, device=device)
        # Discriminator update: one pass over the expert data per epoch.
        for _ in range(1):
            for expert_state_batch, expert_action_batch in data_loader:
                gen_state = torch.cat((discrete_state, continuous_state),
                                      dim=-1)
                gen_action = torch.cat((discrete_action, continuous_action),
                                       dim=-1)
                gen_r = discriminator(gen_state, gen_action)
                # NOTE(review): the expert batch is fed as-is — unlike the
                # generated tensors no .to(device) is applied here; confirm
                # device placement matches.
                expert_r = discriminator(expert_state_batch,
                                         expert_action_batch)
                optimizer_discriminator.zero_grad()
                # Generated pairs -> 0, expert pairs -> 1.
                d_loss = discriminator_criterion(
                    gen_r, torch.zeros(gen_r.shape, device=device)) + \
                    discriminator_criterion(
                        expert_r, torch.ones(expert_r.shape, device=device))
                # Variance-regularized variant; computed but intentionally
                # not used for the backward pass (see commented line).
                total_d_loss = d_loss - 10 * torch.var(gen_r.to(device))
                d_loss.backward()
                # total_d_loss.backward()
                optimizer_discriminator.step()
        writer.add_scalar('d_loss', d_loss, ep)
        # writer.add_scalar('total_d_loss', total_d_loss, ep)
        writer.add_scalar('expert_r', expert_r.mean(), ep)
        # update PPO: discriminator output on the rollout is the reward.
        gen_r = discriminator(
            torch.cat((discrete_state, continuous_state), dim=-1),
            torch.cat((discrete_action, continuous_action), dim=-1))
        optimize_iter_num = int(
            math.ceil(discrete_state.shape[0] / args.ppo_mini_batch_size))
        for ppo_ep in range(args.ppo_optim_epoch):
            for i in range(optimize_iter_num):
                num += 1
                index = slice(
                    i * args.ppo_mini_batch_size,
                    min((i + 1) * args.ppo_mini_batch_size,
                        discrete_state.shape[0]))
                discrete_state_batch, continuous_state_batch, discrete_action_batch, continuous_action_batch, \
                    old_log_prob_batch, mask_batch, next_discrete_state_batch, next_continuous_state_batch, gen_r_batch = \
                    discrete_state[index], continuous_state[index], discrete_action[index], continuous_action[index], \
                    old_log_prob[index], mask[index], next_discrete_state[index], next_continuous_state[index], gen_r[
                        index]
                v_loss, p_loss = ppo_step(
                    policy, value, optimizer_policy, optimizer_value,
                    discrete_state_batch, continuous_state_batch,
                    discrete_action_batch, continuous_action_batch,
                    next_discrete_state_batch, next_continuous_state_batch,
                    gen_r_batch, old_log_prob_batch, mask_batch,
                    args.ppo_clip_epsilon)
                writer.add_scalar('p_loss', p_loss, num)
                writer.add_scalar('v_loss', v_loss, num)
        writer.add_scalar('gen_r', gen_r.mean(), num)
        print('#' * 5 + 'training episode:{}'.format(ep) + '#' * 5)
        print('d_loss', d_loss.item())
        # print('p_loss', p_loss.item())
        # print('v_loss', v_loss.item())
        print('gen_r:', gen_r.mean().item())
        print('expert_r:', expert_r.mean().item())
        memory.clear_memory()
        # save models
        torch.save(discriminator.state_dict(),
                   './model_pkl/Discriminator_model_4.pkl')
        torch.save(policy.transition_net.state_dict(),
                   './model_pkl/Transition_model_4.pkl')
        torch.save(policy.policy_net.state_dict(),
                   './model_pkl/Policy_model_4.pkl')
        torch.save(value.state_dict(), './model_pkl/Value_model_4.pkl')
class HiDDen(object):
    """HiDDeN image-watermarking model: encoder/decoder + adversarial discriminator.

    The encoder embeds a binary message into a cover image and the decoder
    recovers it; a discriminator is trained to tell cover images from encoded
    ones, while the encoder is trained adversarially to fool it.
    """

    def __init__(self, config: HiDDenConfiguration, device: torch.device):
        # Networks.
        self.enc_dec = EncoderDecoder(config).to(device)
        self.discr = Discriminator(config).to(device)
        # Optimizers (default Adam hyper-parameters).
        self.opt_enc_dec = torch.optim.Adam(self.enc_dec.parameters())
        self.opt_discr = torch.optim.Adam(self.discr.parameters())

        self.config = config
        self.device = device

        # BCE-with-logits for the GAN game, MSE for image/message fidelity.
        self.bce_with_logits_loss = nn.BCEWithLogitsLoss().to(device)
        self.mse_loss = nn.MSELoss().to(device)

        # Discriminator target labels.
        self.cover_label = 1
        self.encod_label = 0

    def _make_labels(self, batch_size: int):
        """Return (d_cover, d_encoded, g_encoded) float target tensors of shape (B, 1).

        Targets must be float: BCEWithLogitsLoss rejects integer targets, and
        torch.full() infers an integer dtype from an int fill value on modern
        PyTorch — hence the explicit dtype=torch.float.
        """
        d_cover = torch.full((batch_size, 1), self.cover_label,
                             dtype=torch.float, device=self.device)
        d_encoded = torch.full((batch_size, 1), self.encod_label,
                               dtype=torch.float, device=self.device)
        # The generator wants the discriminator to call encoded images "cover".
        g_encoded = torch.full((batch_size, 1), self.cover_label,
                               dtype=torch.float, device=self.device)
        return d_cover, d_encoded, g_encoded

    @staticmethod
    def _bitwise_error(decoded_messages, messages, batch_size):
        """Mean bit error between rounded decoded messages and ground truth."""
        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(0, 1)
        return np.sum(np.abs(decoded_rounded - messages.detach().cpu().numpy())) \
            / (batch_size * messages.shape[1])

    def train_on_batch(self, batch: list):
        """Train discriminator then encoder/decoder on one (images, messages) batch.

        Returns (losses_dict, (encoded_images, decoded_messages)).
        """
        images, messages = batch
        batch_size = images.shape[0]
        self.enc_dec.train()
        self.discr.train()
        with torch.enable_grad():
            d_target_label_cover, d_target_label_encoded, g_target_label_encoded = \
                self._make_labels(batch_size)

            # ---------- Train the discriminator ----------
            self.opt_discr.zero_grad()
            # ... on real (cover) images.
            d_on_cover = self.discr(images)
            d_loss_on_cover = self.bce_with_logits_loss(d_on_cover,
                                                        d_target_label_cover)
            d_loss_on_cover.backward()

            # ... on fake (encoded) images; detach so no gradient reaches the encoder.
            encoded_images, decoded_messages = self.enc_dec(images, messages)
            d_on_encoded = self.discr(encoded_images.detach())
            d_loss_on_encod = self.bce_with_logits_loss(d_on_encoded,
                                                        d_target_label_encoded)
            d_loss_on_encod.backward()
            self.opt_discr.step()

            # ---------- Train the generator (encoder/decoder) ----------
            self.opt_enc_dec.zero_grad()
            d_on_encoded_for_enc = self.discr(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)
            g_loss_enc = self.mse_loss(encoded_images, images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)
            g_loss = self.config.adversarial_loss * g_loss_adv \
                + self.config.encoder_loss * g_loss_enc \
                + self.config.decoder_loss * g_loss_dec
            g_loss.backward()
            self.opt_enc_dec.step()

        losses = {
            'loss': g_loss.item(),
            'encoder_mse': g_loss_enc.item(),
            'decoder_mse': g_loss_dec.item(),
            'bitwise-error': self._bitwise_error(decoded_messages, messages,
                                                 batch_size),
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encod.item()
        }
        return losses, (encoded_images, decoded_messages)

    def validate_on_batch(self, batch: list):
        """Evaluate all losses on one (images, messages) batch without training."""
        images, messages = batch
        batch_size = images.shape[0]
        self.enc_dec.eval()
        self.discr.eval()
        with torch.no_grad():
            d_target_label_cover, d_target_label_encoded, g_target_label_encoded = \
                self._make_labels(batch_size)

            d_on_cover = self.discr(images)
            d_loss_on_cover = self.bce_with_logits_loss(d_on_cover,
                                                        d_target_label_cover)

            encoded_images, decoded_messages = self.enc_dec(images, messages)
            d_on_encoded = self.discr(encoded_images)
            d_loss_on_encod = self.bce_with_logits_loss(d_on_encoded,
                                                        d_target_label_encoded)

            # Reuse the same discriminator forward pass for the generator's
            # adversarial loss (the original ran it twice on identical input).
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded,
                                                   g_target_label_encoded)
            g_loss_enc = self.mse_loss(encoded_images, images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)
            g_loss = self.config.adversarial_loss * g_loss_adv \
                + self.config.encoder_loss * g_loss_enc \
                + self.config.decoder_loss * g_loss_dec

        losses = {
            'loss': g_loss.item(),
            'encoder_mse': g_loss_enc.item(),
            'decoder_mse': g_loss_dec.item(),
            # Keys normalized to match train_on_batch (were 'bitwise-err' /
            # 'discr_enced_bce' — apparent typos).
            'bitwise-error': self._bitwise_error(decoded_messages, messages,
                                                 batch_size),
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encod.item()
        }
        return losses, (encoded_images, decoded_messages)

    def to_string(self):
        """Human-readable summary of both networks."""
        return f'{str(self.enc_dec)}\n{str(self.discr)}'

    # Backward-compatible alias for the original (misspelled) method name.
    to_stirng = to_string
def main():
    """Adversarial training entry point: segmentor (G) vs discriminator (D).

    Builds both networks, the data loaders, SGD/Adam optimizers with a
    poly-decay LR schedule, optionally resumes from the last checkpoint,
    then trains for settings.EPOCHS epochs, checkpointing periodically.
    """
    # Seed torch and numpy for reproducibility.
    torch.manual_seed(27)
    np.random.seed(27)

    # TensorBoard writer (re-created with a purge step when resuming).
    writer = SummaryWriter(settings.TENSORBOARD_DIR)

    # Make sure the checkpoint directory exists.
    makedir(settings.CHECKPOINT_DIR)

    # Enable cudnn.
    torch.backends.cudnn.enabled = True

    # Create the segmentor (generator) network.
    model_G = Segmentor(pretrained=settings.PRETRAINED,
                        num_classes=settings.NUM_CLASSES,
                        modality=settings.MODALITY)
    model_G.train()
    model_G.cuda()
    torch.backends.cudnn.benchmark = True

    # Create the discriminator network.
    model_D = Discriminator(settings.NUM_CLASSES)
    model_D.train()
    model_D.cuda()

    # Datasets and loaders.
    dataset = TrainDataset()
    dataloader = data.DataLoader(dataset,
                                 batch_size=settings.BATCH_SIZE,
                                 shuffle=True,
                                 num_workers=settings.NUM_WORKERS,
                                 pin_memory=True,
                                 drop_last=True)
    test_dataset = TestDataset(data_root=settings.DATA_ROOT_VAL,
                               data_list=settings.DATA_LIST_VAL)
    test_dataloader = data.DataLoader(test_dataset,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=settings.NUM_WORKERS,
                                      pin_memory=True)

    # Shared poly LR decay (identical schedules for G and D in the original).
    poly_decay = lambda epoch: (1 - epoch / settings.EPOCHS) ** settings.LR_POLY_POWER

    # Optimizer + scheduler for the segmentor.
    optim_G = optim.SGD(model_G.optim_parameters(settings.LR),
                        lr=settings.LR,
                        momentum=settings.LR_MOMENTUM,
                        weight_decay=settings.WEIGHT_DECAY)
    lr_scheduler_G = optim.lr_scheduler.LambdaLR(optim_G, lr_lambda=poly_decay)

    # Optimizer + scheduler for the discriminator.
    optim_D = optim.Adam(model_D.parameters(), settings.LR_D)
    lr_scheduler_D = optim.lr_scheduler.LambdaLR(optim_D, lr_lambda=poly_decay)

    # Losses: CE for the segmentor, BCE for the discriminator.
    ce_loss = CrossEntropyLoss2d(ignore_index=settings.IGNORE_LABEL)
    bce_loss = BCEWithLogitsLoss2d()

    # Upsampling for the network output.
    upsample = nn.Upsample(size=(settings.CROP_SIZE, settings.CROP_SIZE),
                           mode='bilinear',
                           align_corners=True)

    # Labels for adversarial training (pred_label = 0, gt_label = 1) are
    # handled inside train_one_epoch.

    # Resume training from the last checkpoint if requested.
    last_epoch = -1
    if settings.RESUME_TRAIN:
        checkpoint = torch.load(settings.LAST_CHECKPOINT)
        model_G.load_state_dict(checkpoint['model_G_state_dict'])
        model_G.train()
        model_G.cuda()
        model_D.load_state_dict(checkpoint['model_D_state_dict'])
        model_D.train()
        model_D.cuda()
        optim_G.load_state_dict(checkpoint['optim_G_state_dict'])
        optim_D.load_state_dict(checkpoint['optim_D_state_dict'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G_state_dict'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D_state_dict'])
        last_epoch = checkpoint['epoch']
        # Purge event logs written after the checkpointed epoch.
        writer = SummaryWriter(settings.TENSORBOARD_DIR,
                               purge_step=(last_epoch + 1) * len(dataloader))

    for epoch in range(last_epoch + 1, settings.EPOCHS + 1):
        train_one_epoch(model_G, model_D, optim_G, optim_D, dataloader,
                        test_dataloader, epoch, upsample, ce_loss, bce_loss,
                        writer, print_freq=5, eval_freq=settings.EVAL_FREQ)

        # Periodic checkpoint; skip at the final epoch so it isn't saved twice.
        if (epoch % settings.CHECKPOINT_FREQ == 0 and epoch != 0
                and epoch < settings.EPOCHS):
            save_checkpoint(epoch, model_G, model_D, optim_G, optim_D,
                            lr_scheduler_G, lr_scheduler_D)

        # Save the final model.
        if epoch >= settings.EPOCHS:
            print('saving the final model')
            save_checkpoint(epoch, model_G, model_D, optim_G, optim_D,
                            lr_scheduler_G, lr_scheduler_D)

        lr_scheduler_G.step()
        lr_scheduler_D.step()

    # Close the writer once training is done (was closed mid-loop before,
    # risking writes to a closed writer on the last scheduler steps).
    writer.close()
def main():
    """Adversarial semi-supervised segmentation training loop (CPU variant).

    Trains a DeepLab segmentor against a fully-convolutional discriminator:
    the segmentor minimizes CE + adversarial loss on its predictions, the
    discriminator learns to tell predicted label maps from one-hot ground
    truth. Snapshots are saved every args.save_pred_every iterations.
    """
    # Parse input size "H,W".
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    # Create the segmentation network (pretrained-weight loading and CUDA
    # are disabled in this CPU debugging variant).
    model = DeepLab(num_classes=args.num_classes)
    model.train()
    model.cpu()

    # Create the discriminator network.
    model_D = Discriminator(num_classes=args.num_classes)
    model_D.train()
    model_D.cpu()

    # MILESTONE 1
    print("Printing MODELS ...")
    print(model)
    print(model_D)

    # Directory for model snapshots.
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Train data and ground-truth labels (custom stand-in datasets replacing
    # the original VOC loaders).
    train_dataset = MyCustomDataset()
    train_gt_dataset = MyCustomDataset()
    trainloader = data.DataLoader(train_dataset, batch_size=5, shuffle=True)
    trainloader_gt = data.DataLoader(train_gt_dataset, batch_size=5, shuffle=True)
    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # MILESTONE 2
    print("Printing Loaders")
    print(trainloader_iter)
    print(trainloader_gt_iter)

    # Optimizer for the segmentation network.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # Optimizer for the discriminator network.
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # MILESTONE 3
    print("Printing OPTIMIZERS ...")
    print(optimizer)
    print(optimizer_D)

    # Losses / bilinear upsampling back to the input resolution.
    bce_loss = BCEWithLogitsLoss2d()
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)

    # Labels for adversarial training.
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        # Semi-supervised losses are reported but never computed in this
        # variant (that branch of the upstream AdvSemiSeg code is disabled).
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # ---- train G ----
            # Don't accumulate grads in D while updating the segmentor.
            for param in model_D.parameters():
                param.requires_grad = False

            # Train with labeled source data; restart the loader when exhausted.
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cpu()
            ignore_mask = (labels.numpy() == 255)

            # Segmentation prediction, upsampled to the input size.
            pred = interp(model(images))

            # (Spatial multi-class) cross-entropy loss.
            loss_seg = loss_calc(pred, labels)

            # Discriminator output on the class-probability map
            # (dim=1 made explicit: implicit-dim softmax is deprecated).
            D_out = interp(model_D(F.softmax(pred, dim=1)))

            # Adversarial loss: push D toward labeling predictions as "gt".
            loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, ignore_mask))

            # Multi-task loss, normalized over gradient-accumulation steps.
            loss = (loss_seg + args.lambda_adv_pred * loss_adv_pred) / args.iter_size
            loss.backward()

            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy() / args.iter_size

            # ---- train D ----
            # Bring back requires_grad for the discriminator parameters.
            for param in model_D.parameters():
                param.requires_grad = True

            # Train with the (detached) segmentor prediction.
            pred = pred.detach()
            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # Train with one-hot ground truth; restart the loader when exhausted.
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cpu()
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        # Final save, then stop.
        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        # Periodic snapshot.
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    # NOTE(review): `start` is presumably a module-level timestamp taken at
    # import time — confirm it is defined before running this script.
    print(end - start, 'seconds')