def __init__(self, config):
    self.config = config
    self.device = config.device
    self.max_itr = config.max_itr
    self.batch_size = config.batch_size
    self.img_size = config.img_size
    self.dim_z = config.dim_z
    self.dim_c = config.dim_c
    self.scale = config.scale
    self.n_gen = config.n_gen
    self.start_itr = 1

    dataloader = DataLoader(config.data_root, config.dataset_name, config.img_size,
                            config.batch_size, config.with_label)
    train_loader, test_loader = dataloader.get_loader(only_train=True)
    self.dataloader = train_loader
    self.dataloader = endless_dataloader(self.dataloader)

    self.generator = Generator(config).to(config.device)
    self.discriminator = Discriminator(config).to(config.device)
    self.optim_g = torch.optim.Adam(self.generator.parameters(), lr=config.lr_g,
                                    betas=(config.beta1, config.beta2))
    self.optim_d = torch.optim.Adam(self.discriminator.parameters(), lr=config.lr_d,
                                    betas=(config.beta1, config.beta2))
    self.criterion = GANLoss()

    if not self.config.checkpoint_path == '':
        self._load_models(self.config.checkpoint_path)

    self.x, self.y, self.r = get_coordinates(self.img_size, self.img_size, self.scale, self.batch_size)
    self.x, self.y, self.r = self.x.to(self.device), self.y.to(self.device), self.r.to(self.device)

    self.writer = SummaryWriter(log_dir=config.log_dir)
def __init__(self, args):
    self.args = args
    self.device = args.device
    self.start_iter = 1
    self.train_iters = args.train_iters

    # coeffs
    self.lambda_A = args.lambda_A
    self.lambda_B = args.lambda_B
    self.lambda_idt = args.lambda_idt

    self.dataloader_A, self.dataloader_B = get_dataloader(args)

    self.D_B, self.G_AB = get_model(args)
    self.D_A, self.G_BA = get_model(args)

    self.criterion_GAN = GANLoss(use_lsgan=args.use_lsgan).to(args.device)
    self.criterion_cycle = nn.L1Loss()
    self.criterion_idt = nn.L1Loss()

    self.optimizer_D = torch.optim.Adam(
        itertools.chain(self.D_B.parameters(), self.D_A.parameters()),
        lr=args.lr, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay)
    self.optimizer_G = torch.optim.Adam(
        itertools.chain(self.G_AB.parameters(), self.G_BA.parameters()),
        lr=args.lr, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay)

    self.logger = self.get_logger(args)
    self.writer = SummaryWriter(args.log_dir)
    save_args(args.log_dir, args)
def __init__(self, **kwargs):
    self.netG_in_channels = kwargs['netG_in_channels']
    self.netG_out_channels = kwargs['netG_out_channels']
    self.phase = kwargs['phase']
    self.device = kwargs['device']
    self.gpus = [int(x) for x in list(kwargs['gpu'])]

    if self.phase == 'train':
        use_sigmoid = not kwargs['use_lsgan']
        self.netG = resnet152_fpn(self.netG_in_channels, self.netG_out_channels, pretrained=False)
        self.netD = NLayerDiscriminator(self.netG_in_channels + self.netG_out_channels, 64,
                                        use_sigmoid=use_sigmoid, init_type='normal')
        if len(kwargs['gpu']) > 1:
            self.netG = nn.DataParallel(self.netG, device_ids=self.gpus)
            self.netD = nn.DataParallel(self.netD, device_ids=self.gpus)
        self.netG.to(self.device)
        self.netD.to(self.device)
    else:
        self.netG = resnet152_fpn(self.netG_in_channels, self.netG_out_channels, pretrained=False)
        print('Loading model from {}.'.format(kwargs['model_file']))
        self.netG.load_state_dict(torch.load(kwargs['model_file']))
        self.netG.to(self.device)
        self.netG.eval()

    if self.phase == 'train':
        # self.fake_AB_pool = ImagePool(kwargs['poolsize'])
        self.GANloss = GANLoss(self.device, use_lsgan=kwargs['use_lsgan'])
        self.L1loss = nn.L1Loss()
        self.lambda_L1 = kwargs['lambda_L1']
        self.CEloss = nn.CrossEntropyLoss()
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=kwargs['lr'], betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=kwargs['lr'], betas=(0.5, 0.999))

        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch - kwargs['niter']) / float(kwargs['niter_decay'] + 1)
            return lr_l

        self.scheduler_G = torch.optim.lr_scheduler.LambdaLR(self.optimizer_G, lr_lambda=lambda_rule)
        self.scheduler_D = torch.optim.lr_scheduler.LambdaLR(self.optimizer_D, lr_lambda=lambda_rule)
def initialize(self, n_input_channels, n_output_channels, n_blocks, initial_filters,
               dropout_value, lr, batch_size, image_width, image_height, gpu_ids, gan,
               pool_size, n_blocks_discr):
    self.input_img = self.tensor(batch_size, n_input_channels, image_height, image_width)
    self.input_gt = self.tensor(batch_size, n_output_channels, image_height, image_width)

    self.generator = UNetV2(n_input_channels, n_output_channels, n_blocks, initial_filters,
                            gpu_ids=gpu_ids)

    if gan:
        self.discriminator = ImageDiscriminatorConv(n_output_channels, initial_filters,
                                                    dropout_value, gpu_ids=gpu_ids,
                                                    n_blocks=n_blocks_discr)
        self.criterion_gan = GANLoss(tensor=self.tensor)
        self.optimizer_dis = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
        self.fake_mask_pool = ImagePool(pool_size)

    if self.load_network:
        self._load_network(self.generator, 'Model', self.load_epoch)
        if gan:
            self._load_network(self.discriminator, 'Discriminator', self.load_epoch)

    self.criterion_seg = BCELoss2d()
    self.optimizer_seg = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(0.5, 0.999))

    print('---------- Network initialized -------------')
    self.print_network(self.generator)
    if gan:
        self.print_network(self.discriminator)
    print('-----------------------------------------------')
def __init__(self, decoder, device=None, model=None, dim_rand=30, n_cls=10,
             learning_rate=1e-3, act=F.relu):
    # Settings
    self.device = device
    self.dim_rand = dim_rand
    self.n_cls = n_cls
    self.act = act
    self.learning_rate = learning_rate  # store the rate actually passed to the optimizers

    # Model
    generator, discriminator = create_gan_experiment(model=model, act=act, dim_rand=dim_rand)
    self.generator = generator
    self.generator.to_gpu(device) if self.device else None
    self.discriminator = discriminator
    self.discriminator.to_gpu(device) if self.device else None
    self.decoder = decoder

    # Optimizer
    self.optimizer_gen = optimizers.Adam(learning_rate)
    self.optimizer_gen.setup(self.generator)
    self.optimizer_gen.use_cleargrads()
    self.optimizer_dis = optimizers.Adam(learning_rate)
    self.optimizer_dis.setup(self.discriminator)
    self.optimizer_dis.use_cleargrads()

    # Losses
    self.gan_loss = GANLoss()
def __init__(self, checkpoints_dir, lr=0.0002, niter_decay=45, batch_size=4,
             gpu_ids=[0, 1], isTrain=True):
    # Hyperparams
    self.lr = lr
    self.beta1 = 0.5
    self.niter_decay = niter_decay
    self.input_nc = 11    # number of input channels
    self.output_nc = 3    # number of output channels
    self.label_nc = 11    # number of mask channels
    self.isTrain = isTrain  # whether to train
    self.dis_net_input_nc = self.input_nc + self.output_nc
    self.dis_n_layers = 3
    self.num_D = 2
    self.lambda_feat = 10.0
    self.z_dim = 512
    self.batch_size = batch_size
    self.gpu_ids = gpu_ids
    self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor

    # Loss function parameters - used in init_loss_function
    self.use_gan_feat_loss = True
    self.no_vgg_loss = True
    self.no_l2_loss = True
    self.checkpoints_dir = checkpoints_dir

    # Optimization parameters
    self.use_lsgan = False
    self.no_ganFeat_loss = True

    self.gen_net = GeneratorNetwork(self.input_nc, self.output_nc)
    if len(gpu_ids) > 0:
        self.gen_net.cuda(gpu_ids[0])
    self.gen_net.apply(weights_init)

    if self.isTrain:
        use_sigmoid = True
        self.dis_net = DiscriminatorNetwork(self.dis_net_input_nc, self.dis_n_layers,
                                            self.num_D, use_sigmoid)
        if len(gpu_ids) > 0:
            self.dis_net.cuda(gpu_ids[0])
        self.dis_net.apply(weights_init)

        # Don't know why we need this???
        self.dis_net2 = DiscriminatorNetwork(self.dis_net_input_nc, self.dis_n_layers,
                                             self.num_D, use_sigmoid)
        if len(gpu_ids) > 0:
            self.dis_net2.cuda(gpu_ids[0])
        self.dis_net2.apply(weights_init)

    # self.p_net = PNetwork(self.label_nc, self.output_nc)
    # self.p_net.apply(weights_init)
    self.p_net = PNetwork(self.batch_size, self.checkpoints_dir)  # TODO

    longSize = 256
    n_downsample_global = 2
    embed_feature_size = longSize // 2**n_downsample_global

    self.encoder_skin_net = EncoderGenerator_mask_skin(functools.partial(nn.BatchNorm2d, affine=True))
    if len(gpu_ids) > 0:
        self.encoder_skin_net.cuda(gpu_ids[0])
    self.encoder_skin_net.apply(weights_init)

    self.encoder_hair_net = EncoderGenerator_mask_skin(functools.partial(nn.BatchNorm2d, affine=True))
    if len(gpu_ids) > 0:
        self.encoder_hair_net.cuda(gpu_ids[0])
    self.encoder_hair_net.apply(weights_init)

    self.encoder_left_eye_net = EncoderGenerator_mask_eye(functools.partial(nn.BatchNorm2d, affine=True))
    if len(gpu_ids) > 0:
        self.encoder_left_eye_net.cuda(gpu_ids[0])
    self.encoder_left_eye_net.apply(weights_init)

    self.encoder_right_eye_net = EncoderGenerator_mask_eye(functools.partial(nn.BatchNorm2d, affine=True))
    if len(gpu_ids) > 0:
        self.encoder_right_eye_net.cuda(gpu_ids[0])
    self.encoder_right_eye_net.apply(weights_init)

    self.encoder_mouth_net = EncoderGenerator_mask_mouth(functools.partial(nn.BatchNorm2d, affine=True))
    if len(gpu_ids) > 0:
        self.encoder_mouth_net.cuda(gpu_ids[0])
    self.encoder_mouth_net.apply(weights_init)

    self.decoder_skin_net = DecoderGenerator_mask_skin(functools.partial(nn.BatchNorm2d, affine=True))
    if len(gpu_ids) > 0:
        self.decoder_skin_net.cuda(gpu_ids[0])
    self.decoder_skin_net.apply(weights_init)

    self.decoder_hair_net = DecoderGenerator_mask_skin(functools.partial(nn.BatchNorm2d, affine=True))
    if len(gpu_ids) > 0:
        self.decoder_hair_net.cuda(gpu_ids[0])
    self.decoder_hair_net.apply(weights_init)

    self.decoder_left_eye_net = DecoderGenerator_mask_eye(functools.partial(nn.BatchNorm2d, affine=True))
    if len(gpu_ids) > 0:
        self.decoder_left_eye_net.cuda(gpu_ids[0])
    self.decoder_left_eye_net.apply(weights_init)

    self.decoder_right_eye_net = DecoderGenerator_mask_eye(functools.partial(nn.BatchNorm2d, affine=True))
    if len(gpu_ids) > 0:
        self.decoder_right_eye_net.cuda(gpu_ids[0])
    self.decoder_right_eye_net.apply(weights_init)

    self.decoder_mouth_net = DecoderGenerator_mask_mouth(functools.partial(nn.BatchNorm2d, affine=True))
    if len(gpu_ids) > 0:
        self.decoder_mouth_net.cuda(gpu_ids[0])
    self.decoder_mouth_net.apply(weights_init)

    self.decoder_skin_image_net = DecoderGenerator_mask_skin_image(functools.partial(nn.BatchNorm2d, affine=True))
    if len(gpu_ids) > 0:
        self.decoder_skin_image_net.cuda(gpu_ids[0])
    self.decoder_skin_image_net.apply(weights_init)

    self.decoder_hair_image_net = DecoderGenerator_mask_skin_image(functools.partial(nn.BatchNorm2d, affine=True))
    if len(gpu_ids) > 0:
        self.decoder_hair_image_net.cuda(gpu_ids[0])
    self.decoder_hair_image_net.apply(weights_init)

    self.decoder_left_eye_image_net = DecoderGenerator_mask_eye_image(functools.partial(nn.BatchNorm2d, affine=True))
    if len(gpu_ids) > 0:
        self.decoder_left_eye_image_net.cuda(gpu_ids[0])
    self.decoder_left_eye_image_net.apply(weights_init)

    self.decoder_right_eye_image_net = DecoderGenerator_mask_eye_image(functools.partial(nn.BatchNorm2d, affine=True))
    if len(gpu_ids) > 0:
        self.decoder_right_eye_image_net.cuda(gpu_ids[0])
    self.decoder_right_eye_image_net.apply(weights_init)

    self.decoder_mouth_image_net = DecoderGenerator_mask_mouth_image(functools.partial(nn.BatchNorm2d, affine=True))
    if len(gpu_ids) > 0:
        self.decoder_mouth_image_net.cuda(gpu_ids[0])
    self.decoder_mouth_image_net.apply(weights_init)

    if self.isTrain:
        self.loss_filter = self.init_loss_filter(self.no_ganFeat_loss, self.no_vgg_loss, self.no_l2_loss)
        self.old_lr = self.lr

        self.criterionGAN = GANLoss(use_lsgan=self.use_lsgan, tensor=self.Tensor)
        self.criterionFeat = torch.nn.L1Loss()
        self.criterionL2 = torch.nn.MSELoss()
        self.criterionL1 = torch.nn.L1Loss()
        # self.criterionMFM = MFMLoss()
        weight_list = [0.2, 1, 5, 5, 5, 5, 3, 8, 8, 8, 1]
        self.criterionCrossEntropy = torch.nn.CrossEntropyLoss(weight=torch.FloatTensor(weight_list))
        # if self.no_vgg_loss:
        #     self.criterionVGG = VGGLoss(weights=None)
        # self.criterionGM = GramMatrixLoss()
        print(self.loss_filter)

        self.loss_names = self.loss_filter('KL_embed', 'L2_mask_image', 'G_GAN', 'G_GAN_Feat',
                                           'D_real', 'D_fake', 'L2_image', 'G2_GAN', 'D2_real', 'D2_fake')

        params_decoder = (list(self.decoder_skin_net.parameters()) +
                          list(self.decoder_hair_net.parameters()) +
                          list(self.decoder_left_eye_net.parameters()) +
                          list(self.decoder_right_eye_net.parameters()) +
                          list(self.decoder_mouth_net.parameters()))
        params_image_decoder = (list(self.decoder_skin_image_net.parameters()) +
                                list(self.decoder_hair_image_net.parameters()) +
                                list(self.decoder_left_eye_image_net.parameters()) +
                                list(self.decoder_right_eye_image_net.parameters()) +
                                list(self.decoder_mouth_image_net.parameters()))
        params_encoder = (list(self.encoder_skin_net.parameters()) +
                          list(self.encoder_hair_net.parameters()) +
                          list(self.encoder_left_eye_net.parameters()) +
                          list(self.encoder_right_eye_net.parameters()) +
                          list(self.encoder_mouth_net.parameters()))
        params_together = list(self.gen_net.parameters()) + params_decoder + params_encoder + params_image_decoder
        self.optimizer_G_together = torch.optim.Adam(params_together, lr=self.lr, betas=(self.beta1, 0.999))

        params = list(self.dis_net.parameters())
        # self.optimizer_D = torch.optim.Adam(params, lr=self.lr, betas=(self.beta1, 0.999))
        self.optimizer_D = torch.optim.RMSprop(params, lr=self.lr)

        # optimizer D2
        params = list(self.dis_net2.parameters())
        # self.optimizer_D2 = torch.optim.Adam(params, lr=self.lr, betas=(self.beta1, 0.999))
        self.optimizer_D2 = torch.optim.RMSprop(params, lr=self.lr)
def main():
    args = args_initialize()
    save_freq = args.save_freq
    epochs = args.num_epoch
    cuda = args.cuda

    train_dataset = UnalignedDataset(is_train=True)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)

    net_G_A = ResNetGenerator(input_nc=3, output_nc=3)
    net_G_B = ResNetGenerator(input_nc=3, output_nc=3)
    net_D_A = Discriminator()
    net_D_B = Discriminator()
    if args.cuda:
        net_G_A = net_G_A.cuda()
        net_G_B = net_G_B.cuda()
        net_D_A = net_D_A.cuda()
        net_D_B = net_D_B.cuda()

    fake_A_pool = ImagePool(50)
    fake_B_pool = ImagePool(50)

    criterionGAN = GANLoss(cuda=cuda)
    criterionCycle = torch.nn.L1Loss()
    criterionIdt = torch.nn.L1Loss()

    optimizer_G = torch.optim.Adam(
        itertools.chain(net_G_A.parameters(), net_G_B.parameters()),
        lr=args.lr, betas=(args.beta1, 0.999))
    optimizer_D_A = torch.optim.Adam(net_D_A.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    optimizer_D_B = torch.optim.Adam(net_D_B.parameters(), lr=args.lr, betas=(args.beta1, 0.999))

    log_dir = './logs'
    checkpoints_dir = './checkpoints'
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoints_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)

    for epoch in range(epochs):
        running_loss = np.zeros((8))
        for batch_idx, data in enumerate(train_loader):
            input_A = data['A']
            input_B = data['B']
            if cuda:
                input_A = input_A.cuda()
                input_B = input_B.cuda()
            real_A = Variable(input_A)
            real_B = Variable(input_B)

            """ Backward net_G """
            optimizer_G.zero_grad()

            lambda_idt = 0.5
            lambda_A = 10.0
            lambda_B = 10.0

            # Identity loss: feed each generator an image from its own output domain;
            # ideally it should return the input unchanged
            idt_B = net_G_A(real_B)
            loss_idt_A = criterionIdt(idt_B, real_B) * lambda_B * lambda_idt
            idt_A = net_G_B(real_A)
            loss_idt_B = criterionIdt(idt_A, real_A) * lambda_A * lambda_idt

            # GAN loss = D_A(G_A(A))
            # G_A wants the discriminator to judge its generated fakes as real (True)
            fake_B = net_G_A(real_A)
            pred_fake = net_D_A(fake_B)
            loss_G_A = criterionGAN(pred_fake, True)

            fake_A = net_G_B(real_B)
            pred_fake = net_D_B(fake_A)
            loss_G_B = criterionGAN(pred_fake, True)

            rec_A = net_G_B(fake_B)
            loss_cycle_A = criterionCycle(rec_A, real_A) * lambda_A
            rec_B = net_G_A(fake_A)
            loss_cycle_B = criterionCycle(rec_B, real_B) * lambda_B

            loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
            loss_G.backward()
            optimizer_G.step()

            """ update D_A """
            optimizer_D_A.zero_grad()
            fake_B = fake_B_pool.query(fake_B.data)
            pred_real = net_D_A(real_B)
            loss_D_real = criterionGAN(pred_real, True)
            pred_fake = net_D_A(fake_B.detach())
            loss_D_fake = criterionGAN(pred_fake, False)
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()

            """ update D_B """
            optimizer_D_B.zero_grad()
            fake_A = fake_A_pool.query(fake_A.data)
            pred_real = net_D_B(real_A)
            loss_D_real = criterionGAN(pred_real, True)
            pred_fake = net_D_B(fake_A.detach())
            loss_D_fake = criterionGAN(pred_fake, False)
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()
            optimizer_D_B.step()

            ret_loss = np.array([
                loss_G_A.data.detach().cpu().numpy(),
                loss_D_A.data.detach().cpu().numpy(),
                loss_G_B.data.detach().cpu().numpy(),
                loss_D_B.data.detach().cpu().numpy(),
                loss_cycle_A.data.detach().cpu().numpy(),
                loss_cycle_B.data.detach().cpu().numpy(),
                loss_idt_A.data.detach().cpu().numpy(),
                loss_idt_B.data.detach().cpu().numpy()
            ])
            running_loss += ret_loss

        """ Save checkpoints """
        if (epoch + 1) % save_freq == 0:
            save_network(net_G_A, 'G_A', str(epoch + 1))
            save_network(net_D_A, 'D_A', str(epoch + 1))
            save_network(net_G_B, 'G_B', str(epoch + 1))
            save_network(net_D_B, 'D_B', str(epoch + 1))

        running_loss /= len(train_loader)
        losses = running_loss
        print('epoch %d, losses: %s' % (epoch + 1, running_loss))
        writer.add_scalar('loss_G_A', losses[0], epoch)
        writer.add_scalar('loss_D_A', losses[1], epoch)
        writer.add_scalar('loss_G_B', losses[2], epoch)
        writer.add_scalar('loss_D_B', losses[3], epoch)
        writer.add_scalar('loss_cycle_A', losses[4], epoch)
        writer.add_scalar('loss_cycle_B', losses[5], epoch)
        writer.add_scalar('loss_idt_A', losses[6], epoch)
        writer.add_scalar('loss_idt_B', losses[7], epoch)
def __init__(self, params_model):
    super(SEM_PCYC, self).__init__()

    print('Initializing model variables...', end='')
    # Dimension of embedding
    self.dim_out = params_model['dim_out']
    # Dimension of semantic embedding
    self.sem_dim = params_model['sem_dim']
    # Number of classes
    self.num_clss = params_model['num_clss']
    # Sketch model: pre-trained on ImageNet
    self.sketch_model = VGGNetFeats(pretrained=False, finetune=False)
    self.load_weight(self.sketch_model, params_model['path_sketch_model'], 'sketch')
    # Image model: pre-trained on ImageNet
    self.image_model = VGGNetFeats(pretrained=False, finetune=False)
    self.load_weight(self.image_model, params_model['path_image_model'], 'image')
    # Semantic model embedding
    self.sem = []
    for f in params_model['files_semantic_labels']:
        self.sem.append(np.load(f, allow_pickle=True).item())
    self.dict_clss = params_model['dict_clss']
    print('Done')

    print('Initializing trainable models...', end='')
    # Generators
    # Sketch to semantic generator
    self.gen_sk2se = Generator(in_dim=512, out_dim=self.dim_out, noise=False, use_dropout=True)
    # Image to semantic generator
    self.gen_im2se = Generator(in_dim=512, out_dim=self.dim_out, noise=False, use_dropout=True)
    # Semantic to sketch generator
    self.gen_se2sk = Generator(in_dim=self.dim_out, out_dim=512, noise=False, use_dropout=True)
    # Semantic to image generator
    self.gen_se2im = Generator(in_dim=self.dim_out, out_dim=512, noise=False, use_dropout=True)
    # Discriminators
    # Common semantic discriminator
    self.disc_se = Discriminator(in_dim=self.dim_out, noise=True, use_batchnorm=True)
    # Sketch discriminator
    self.disc_sk = Discriminator(in_dim=512, noise=True, use_batchnorm=True)
    # Image discriminator
    self.disc_im = Discriminator(in_dim=512, noise=True, use_batchnorm=True)
    # Semantic autoencoder
    self.aut_enc = AutoEncoder(dim=self.sem_dim, hid_dim=self.dim_out, nlayer=1)
    # Classifiers
    self.classifier_sk = nn.Linear(512, self.num_clss, bias=False)
    self.classifier_im = nn.Linear(512, self.num_clss, bias=False)
    self.classifier_se = nn.Linear(self.dim_out, self.num_clss, bias=False)
    for param in self.classifier_sk.parameters():
        param.requires_grad = False
    for param in self.classifier_im.parameters():
        param.requires_grad = False
    for param in self.classifier_se.parameters():
        param.requires_grad = False
    print('Done')

    # Optimizers
    print('Defining optimizers...', end='')
    self.lr = params_model['lr']
    self.gamma = params_model['gamma']
    self.momentum = params_model['momentum']
    self.milestones = params_model['milestones']
    self.optimizer_gen = optim.Adam(list(self.gen_sk2se.parameters()) + list(self.gen_im2se.parameters()) +
                                    list(self.gen_se2sk.parameters()) + list(self.gen_se2im.parameters()),
                                    lr=self.lr)
    self.optimizer_disc = optim.SGD(list(self.disc_se.parameters()) + list(self.disc_sk.parameters()) +
                                    list(self.disc_im.parameters()), lr=self.lr, momentum=self.momentum)
    self.optimizer_ae = optim.SGD(self.aut_enc.parameters(), lr=100 * self.lr, momentum=self.momentum)
    self.scheduler_gen = optim.lr_scheduler.MultiStepLR(self.optimizer_gen, milestones=self.milestones,
                                                        gamma=self.gamma)
    self.scheduler_disc = optim.lr_scheduler.MultiStepLR(self.optimizer_disc, milestones=self.milestones,
                                                         gamma=self.gamma)
    self.scheduler_ae = optim.lr_scheduler.MultiStepLR(self.optimizer_ae, milestones=self.milestones,
                                                       gamma=self.gamma)
    print('Done')

    # Loss function
    print('Defining losses...', end='')
    self.lambda_se = params_model['lambda_se']
    self.lambda_im = params_model['lambda_im']
    self.lambda_sk = params_model['lambda_sk']
    self.lambda_gen_cyc = params_model['lambda_gen_cyc']
    self.lambda_gen_adv = params_model['lambda_gen_adv']
    self.lambda_gen_cls = params_model['lambda_gen_cls']
    self.lambda_gen_reg = params_model['lambda_gen_reg']
    self.lambda_disc_se = params_model['lambda_disc_se']
    self.lambda_disc_sk = params_model['lambda_disc_sk']
    self.lambda_disc_im = params_model['lambda_disc_im']
    self.lambda_regular = params_model['lambda_regular']
    self.criterion_gan = GANLoss(use_lsgan=True)
    self.criterion_cyc = nn.L1Loss()
    self.criterion_cls = nn.CrossEntropyLoss()
    self.criterion_reg = nn.MSELoss()
    print('Done')

    # Intermediate variables
    print('Initializing variables...', end='')
    self.sk_fe = torch.zeros(1)
    self.sk_em = torch.zeros(1)
    self.im_fe = torch.zeros(1)
    self.im_em = torch.zeros(1)
    self.se_em_enc = torch.zeros(1)
    self.se_em_rec = torch.zeros(1)
    self.im2se_em = torch.zeros(1)
    self.sk2se_em = torch.zeros(1)
    self.se2im_em = torch.zeros(1)
    self.se2sk_em = torch.zeros(1)
    self.im_em_hat = torch.zeros(1)
    self.sk_em_hat = torch.zeros(1)
    self.se_em_hat1 = torch.zeros(1)
    self.se_em_hat2 = torch.zeros(1)
    print('Done')
def __init__(self, opt):
    super(TwoStreamAE_mask, self).__init__(opt)
    if opt.resize_or_crop != 'none':
        torch.backends.cudnn.benchmark = True
    self.isTrain = opt.isTrain
    self.which_stream = opt.which_stream
    self.use_gan = opt.use_gan
    self.which_gan = opt.which_gan
    self.gan_weight = opt.gan_weight
    self.rec_weight = opt.rec_weight
    self.cond_in = opt.cond_in
    self.use_output_gate = opt.use_output_gate
    self.opt = opt

    if opt.no_comb:
        from MaskTwoStreamConvSwitch_NET import MaskTwoStreamConvSwitch_NET as model_factory
    else:
        from MaskTwoStreamConv_NET import MaskTwoStreamConv_NET as model_factory
    model = self.get_model(model_factory)
    self.netG = model(opt)
    self.netG.initialize()

    # move networks to gpu
    if len(opt.gpu_ids) > 0:
        assert(torch.cuda.is_available())
        self.netG.cudafy(opt.gpu_ids[0])
    print('---------- Networks initialized -------------')

    # set loss functions and optimizers
    if self.isTrain:
        self.old_lr = opt.lr
        # define loss functions
        self.criterionRecon = MaskReconLoss()
        if opt.objReconLoss == 'l1':
            self.criterionObjRecon = nn.L1Loss()
        elif opt.objReconLoss == 'bce':
            self.criterionObjRecon = nn.BCELoss()
        else:
            self.criterionObjRecon = None

        # Names so we can breakout loss
        self.loss_names = ['G_Recon_comb', 'G_Recon_obj',
                           'KL_loss', 'loss_G_GAN', 'loss_D_GAN', 'loss_G_GAN_Feat']

        params = self.netG.trainable_parameters
        self.optimizer = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, opt.beta2))

        ########## define discriminator
        if self.use_gan:
            label_nc = opt.label_nc if not (opt.cond_in == 'ctx_obj') else opt.label_nc * 2
            if self.which_gan == 'patch':
                use_lsgan = False
                self.netD = NLayerDiscriminator(
                    input_nc=1 + label_nc, ndf=opt.ndf, n_layers=opt.num_layers_D,
                    norm_layer=opt.norm_layer, use_sigmoid=not use_lsgan, getIntermFeat=False)
            elif self.which_gan == 'patch_res':
                use_lsgan = False
                self.netD = NLayerResDiscriminator(
                    input_nc=1 + label_nc, ndf=opt.ndf, n_layers=opt.num_layers_D,
                    norm_layer=opt.norm_layer, use_sigmoid=not use_lsgan, getIntermFeat=False)
            elif self.which_gan == 'patch_multiscale':
                use_lsgan = True
                self.netD = MultiscaleDiscriminator(
                    1 + label_nc, opt.ndf, opt.num_layers_D, opt.norm_layer, not use_lsgan, 2, True)
            self.ganloss = GANLoss(use_lsgan=use_lsgan, tensor=self.Tensor)
            if opt.use_ganFeat_loss:
                self.criterionFeat = torch.nn.L1Loss()
            if len(opt.gpu_ids) > 0:
                self.netD.cuda(opt.gpu_ids[0])
            params_D = [param for param in self.netD.parameters() if param.requires_grad]
            self.optimizer_D = torch.optim.Adam(params_D, lr=opt.lr, betas=(opt.beta1, 0.999))

    # load networks
    if opt.continue_train or opt.load_pretrain:
        pretrained_path = '' if not self.isTrain else opt.load_pretrain
        self.load_network_dict(self.netG.params_dict, self.optimizer, 'G',
                               opt.which_epoch, opt.load_pretrain)
        if opt.use_gan:
            # TODO(sh): add loading for discriminator optimizer
            self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
    else:
        self.load_network_dict(self.netG.params_dict, None, 'G', opt.which_epoch, '')
class RefinerGAN:
    def __init__(self, args):
        self.args = args
        args.logger.info('Initializing trainer')
        # if not os.path.isdir('../predict'):  only used in validation
        #     os.makedirs('../predict')
        self.model = get_model(args)
        if self.args.lock_coarse:
            for p in self.model.coarse_model.parameters():
                p.requires_grad = False
        torch.cuda.set_device(args.rank)
        self.model.cuda(args.rank)
        self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[args.rank])
        train_dataset, val_dataset = get_dataset(args)

        if not args.val:
            # train loss
            self.coarse_RGBLoss = RGBLoss(args, sharp=False)
            self.refine_RGBLoss = RGBLoss(args, sharp=False, refine=True)
            self.SegLoss = nn.CrossEntropyLoss()
            self.GANLoss = GANLoss(tensor=torch.FloatTensor)
            self.coarse_RGBLoss.cuda(args.rank)
            self.refine_RGBLoss.cuda(args.rank)
            self.SegLoss.cuda(args.rank)
            self.GANLoss.cuda(args.rank)

            if args.optimizer == "adamax":
                self.optG = torch.optim.Adamax(
                    list(self.model.module.coarse_model.parameters()) +
                    list(self.model.module.refine_model.parameters()), lr=args.learning_rate)
            elif args.optimizer == "adam":
                self.optG = torch.optim.Adam(self.model.parameters(), lr=args.learning_rate)
            elif args.optimizer == "sgd":
                self.optG = torch.optim.SGD(self.model.parameters(), lr=args.learning_rate, momentum=0.9)

            # self.optD = torch.optim.Adam(self.model.module.discriminator.parameters(), lr=args.learning_rate)
            self.optD = torch.optim.SGD(self.model.module.discriminator.parameters(),
                                        lr=args.learning_rate, momentum=0.9)

            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
            self.train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=args.batch_size // args.gpus, shuffle=False,
                num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)
        else:
            # val criteria
            self.L1Loss = nn.L1Loss().cuda(args.rank)
            self.PSNRLoss = PSNR().cuda(args.rank)
            self.SSIMLoss = SSIM().cuda(args.rank)
            self.IoULoss = IoU().cuda(args.rank)
            self.VGGCosLoss = VGGCosineLoss().cuda(args.rank)

            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
            self.val_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=args.batch_size // args.gpus, shuffle=False,
                num_workers=args.num_workers, pin_memory=True, sampler=val_sampler)

        torch.backends.cudnn.benchmark = True
        self.global_step = 0
        self.epoch = 1
        if args.resume or (args.val and not args.checkepoch_range):
            self.load_checkpoint()

        if args.rank == 0:
            if args.val:
                self.writer = SummaryWriter(args.path + '/val_logs') if args.interval == 2 else \
                    SummaryWriter(args.path + '/val_int_1_logs')
            else:
                self.writer = SummaryWriter(args.path + '/logs')

        self.heatmap = self.create_stand_heatmap()

    def prepare_heat_map(self, prob_map):
        bs, c, h, w = prob_map.size()
        if h != 128:
            prob_map_ = F.interpolate(prob_map, size=(128, 256), mode='nearest', align_corners=True)
        return prob_map

    def create_heatmap(self, prob_map):
        c, h, w = prob_map.size()
        assert c == 1, c
        assert h == 128, h
        rgb_prob_map = torch.zeros(3, h, w)
        minimum, maximum = 0.0, 1.0
        ratio = 2 * (prob_map - minimum) / (maximum - minimum)
        rgb_prob_map[0] = 1 - ratio
        rgb_prob_map[1] = ratio - 1
        rgb_prob_map[:2].clamp_(0, 1)
        rgb_prob_map[2] = 1 - rgb_prob_map[0] - rgb_prob_map[1]
        return rgb_prob_map

    def create_stand_heatmap(self):
        heatmap = torch.zeros(3, 128, 256)
        for i in range(256):
            heatmap[0, :, i] = max(0, 1 - 2. * i / 256)
            heatmap[1, :, i] = max(0, 2. * i / 256 - 1)
            heatmap[2, :, i] = 1 - heatmap[0, :, i] - heatmap[1, :, i]
        return heatmap

    def set_epoch(self, epoch):
        self.args.logger.info("Start of epoch %d" % (epoch + 1))
        self.epoch = epoch + 1
        self.train_loader.sampler.set_epoch(epoch)
        # self.val_loader.sampler.set_epoch(epoch)

    def get_input(self, data):
        if self.args.mode == 'xs2xs':
            if self.args.syn_type == 'extra':
                x = torch.cat([data['frame1'], data['frame2'], data['seg1'], data['seg2']], dim=1)
                mask = torch.cat([data['fg_mask1'], data['fg_mask2']], dim=1)
                gt = torch.cat([data['frame3'], data['seg3']], dim=1)
            else:
                x = torch.cat([data['frame1'], data['frame3'], data['seg1'], data['seg3']], dim=1)
                mask = torch.cat([data['fg_mask1'], data['fg_mask3']], dim=1)
                gt = torch.cat([data['frame2'], data['seg2']], dim=1)
        elif self.args.mode == 'xss2x':
            if self.args.syn_type == 'extra':
                x = torch.cat([data['frame1'], data['frame2'], data['seg1'], data['seg2'], data['seg3']], dim=1)
                gt = data['frame3']
            else:
                x = torch.cat([data['frame1'], data['frame3'], data['seg1'], data['seg2'], data['seg3']], dim=1)
                gt = data['frame2']
        return x, mask, gt

    def normalize(self, img):
        return (img + 1) / 2

    def prepare_image_set(self, data, coarse_img, refined_imgs, seg, pred_fake, pred_real, extra=False):
        view_rgbs = [self.normalize(data['frame1'][0]),
                     self.normalize(data['frame2'][0]),
                     self.normalize(data['frame3'][0])]
        view_segs = [vis_seg_mask(data['seg' + str(i)][0].unsqueeze(0), 20).squeeze(0) for i in range(1, 4)]

        # gan
        view_probs = []
        view_probs.append(self.heatmap)
        for i in range(self.args.num_D):
            toDraw = F.interpolate(pred_real[i][-1][0].unsqueeze(0).cpu(), (128, 256),
                                   mode='bilinear', align_corners=True).squeeze(0)
            view_probs.append(self.create_heatmap(toDraw))
            toDraw = F.interpolate(pred_fake[i][-1][0].unsqueeze(0).cpu(), (128, 256),
                                   mode='bilinear', align_corners=True).squeeze(0)
            view_probs.append(self.create_heatmap(toDraw))

        if not extra:
            # coarse
            pred_rgb = self.normalize(coarse_img[0])
            pred_seg = vis_seg_mask(seg[0].unsqueeze(0), 20).squeeze(0) \
                if self.args.mode == 'xs2xs' else torch.zeros_like(view_segs[0])
            insert_index = 2 if self.args.syn_type == 'inter' else 3
            view_rgbs.insert(insert_index, pred_rgb)
            view_segs.insert(insert_index, pred_seg)
            view_segs.append(torch.zeros_like(view_segs[-1]))
            # refine
            refined_bs_imgs = [refined_img[0].unsqueeze(0) for refined_img in refined_imgs]
            for i in range(self.args.n_scales):
                insert_img = F.interpolate(refined_bs_imgs[i], size=(128, 256))[0].clamp_(-1, 1)
                pred_rgb = self.normalize(insert_img)
                insert_ind = insert_index + i + 1
                view_rgbs.insert(insert_ind, pred_rgb)
            write_in_img = make_grid(view_rgbs + view_segs + view_probs, nrow=4 + self.args.n_scales)
        # else:
        #     view_rgbs.insert(3, torch.zeros_like(view_rgbs[-1]))
        #     view_segs.insert(3, torch.zeros_like(view_segs[-1]))
        #     view_pred_rgbs = []
        #     view_pred_segs = []
        #     for i in range(self.args.extra_length):
        #         pred_rgb = self.normalize(img[i][0].cpu())
        #         pred_seg = vis_seg_mask(seg[i].cpu(), 20).squeeze(0) if self.args.mode == 'xs2xs' else torch.zeros_like(view_segs[0])
        #         view_pred_rgbs.append(pred_rgb)
        #         view_pred_segs.append(pred_seg)
        #     write_in_img = make_grid(view_rgbs + view_segs + view_pred_rgbs + view_pred_segs, nrow=4)
        return write_in_img

    def train(self):
        self.args.logger.info('Training started')
        self.model.train()
        end = time()
        load_time = 0
        comp_time = 0

        for step, data in enumerate(self.train_loader):
            self.step = step
            load_time += time() - end
            end = time()
            # for tensorboard
            self.global_step += 1

            # forward pass
            x, fg_mask, gt = self.get_input(data)
            x = x.cuda(self.args.rank, non_blocking=True)
            fg_mask = fg_mask.cuda(self.args.rank, non_blocking=True)
            gt = gt.cuda(self.args.rank, non_blocking=True)
            coarse_img, refined_imgs, seg, pred_fake_D, pred_real_D, pred_fake_G = self.model(x, fg_mask, gt)

            if not self.args.lock_coarse:
                loss_dict = self.coarse_RGBLoss(coarse_img, gt[:, :3], False)
                if self.args.mode == 'xs2xs':
                    loss_dict['ce_loss'] = self.args.ce_weight * self.SegLoss(seg, torch.argmax(gt[:, 3:], dim=1))
            else:
                loss_dict = OrderedDict()
            for i in range(self.args.n_scales):
                # print(i, refined_imgs[-i].size())
                loss_dict.update(self.refine_RGBLoss(
                    refined_imgs[-i - 1],
                    F.interpolate(gt[:, :3], scale_factor=(1 / 2)**i, mode='bilinear', align_corners=True),
                    refine_scale=1 / (2**i), step=self.global_step, normed=False))

            # loss and accuracy
            loss = 0
            for i in loss_dict.values():
                loss += torch.mean(i)
            loss_dict['loss_all'] = loss

            if self.global_step > 1000:
                loss_dict['adv_loss'] = self.args.refine_adv_weight * self.GANLoss(pred_fake_G, True)
                g_loss = loss_dict['loss_all'] + loss_dict['adv_loss']
                loss_dict['d_real_loss'] = self.args.refine_d_weight * self.GANLoss(pred_real_D, True)
                loss_dict['d_fake_loss'] = self.args.refine_d_weight * self.GANLoss(pred_fake_D, False)
                loss_dict['d_loss'] = loss_dict['d_real_loss'] + loss_dict['d_fake_loss']
            else:
                g_loss = loss_dict['loss_all']
                loss_dict['d_real_loss'] = 0 * self.GANLoss(pred_real_D, True)
                loss_dict['d_fake_loss'] = 0 * self.GANLoss(pred_fake_D, False)
                loss_dict['d_loss'] = loss_dict['d_real_loss'] + loss_dict['d_fake_loss']

            self.sync(loss_dict)

            self.optG.zero_grad()
            g_loss.backward()
            self.optG.step()

            # discriminator backward pass
            self.optD.zero_grad()
            loss_dict['d_loss'].backward()
            self.optD.step()

            comp_time += time() - end
            end = time()

            if self.args.rank == 0:
                # add info to tensorboard
                info = {key: value.item() for key, value in loss_dict.items()}
                # add discriminator value
                pred_value = 0
                real_value = 0
                for i in range(self.args.num_D):
                    pred_value += torch.mean(pred_fake_D[i][-1])
                    real_value += torch.mean(pred_real_D[i][-1])
                pred_value /= self.args.num_D
                real_value /= self.args.num_D
                info["fake"] = pred_value.item()
                info["real"] = real_value.item()
                self.writer.add_scalars("losses", info, self.global_step)

                # print
                if self.step % self.args.disp_interval == 0:
                    self.args.logger.info(
                        'Epoch [{epoch:d}/{tot_epoch:d}][{cur_batch:d}/{tot_batch:d}] '
                        'load [{load_time:.3f}s] comp [{comp_time:.3f}s] '
                        'loss [{loss:.4f}]'.format(
                            epoch=self.epoch, tot_epoch=self.args.epochs,
                            cur_batch=self.step + 1, tot_batch=len(self.train_loader),
                            load_time=load_time, comp_time=comp_time, loss=loss.item()))
                    comp_time = 0
                    load_time = 0

                if self.step % 50 == 0:
                    image_set = self.prepare_image_set(
                        data, coarse_img.cpu(), [refined_img.cpu() for refined_img in refined_imgs],
                        seg.cpu(), pred_fake_D, pred_real_D)
                    self.writer.add_image('image_{}'.format(self.global_step), image_set, self.global_step)

    def validate(self):
        self.args.logger.info('Validation epoch {} started'.format(self.epoch))
        self.model.eval()

        val_criteria = {
            'l1': AverageMeter(),
            'psnr': AverageMeter(),
            'ssim': AverageMeter(),
            'iou': AverageMeter(),
            'vgg': AverageMeter()
        }
        step_losses = OrderedDict()

        with torch.no_grad():
            end = time()
            load_time = 0
            comp_time = 0

            for i, data in enumerate(self.val_loader):
                load_time += time() - end
                end = time()
                self.step = i

                # forward pass
                x, fg_mask, gt = self.get_input(data)
                size = x.size(0)
                x = x.cuda(self.args.rank, non_blocking=True)
                fg_mask = fg_mask.cuda(self.args.rank, non_blocking=True)
                gt = gt.cuda(self.args.rank, non_blocking=True)
                coarse_img, refined_imgs, seg, pred_fake_D, pred_real_D = self.model(x, fg_mask, gt)

                # rgb criteria
                step_losses['l1'] = self.L1Loss(refined_imgs[-1], gt[:, :3])
                step_losses['psnr'] = self.PSNRLoss((refined_imgs[-1] + 1) / 2, (gt[:, :3] + 1) / 2)
                step_losses['ssim'] = 1 - self.SSIMLoss(refined_imgs[-1], gt[:, :3])
                step_losses['iou'] = self.IoULoss(torch.argmax(seg, dim=1), torch.argmax(gt[:, 3:], dim=1))
                step_losses['vgg'] = self.VGGCosLoss(refined_imgs[-1], gt[:, :3], False)
                self.sync(step_losses)  # sum
                for key in list(val_criteria.keys()):
                    val_criteria[key].update(step_losses[key].cpu().item(), size * self.args.gpus)

                if self.args.syn_type == 'extra':
                    # not implemented
                    imgs = []
                    segs = []
                    img = img[0].unsqueeze(0)
                    seg = seg[0].unsqueeze(0)
                    x = x[0].unsqueeze(0)
                    for i in range(self.args.extra_length):
                        if i != 0:
                            x = torch.cat([x[:, 3:6], img, x[:, 26:46], seg_fil],
                                          dim=1).cuda(self.args.rank, non_blocking=True)
                        img, seg = self.model(x)
                        seg_fil = torch.argmax(seg, dim=1)
                        seg_fil = transform_seg_one_hot(seg_fil, 20, cuda=True) * 2 - 1
                        imgs.append(img)
                        segs.append(seg_fil)

                comp_time += time() - end
                end = time()

                # print
                if self.args.rank == 0:
                    if self.step % self.args.disp_interval == 0:
                        self.args.logger.info(
                            'Epoch [{epoch:d}][{cur_batch:d}/{tot_batch:d}] '
                            'load [{load_time:.3f}s] comp [{comp_time:.3f}s]'.format(
                                epoch=self.epoch, cur_batch=self.step + 1,
                                tot_batch=len(self.val_loader),
                                load_time=load_time, comp_time=comp_time))
                        comp_time = 0
                        load_time = 0
                    if self.step % 3 == 0:
                        if self.args.syn_type == 'inter':
                            image_set = self.prepare_image_set(
                                data, coarse_img.cpu(),
                                [refined_img.cpu() for refined_img in refined_imgs], seg.cpu(),
                                pred_fake_D, pred_real_D)
                        else:
                            image_set = self.prepare_image_set(data, imgs, segs, True)
                        image_name = 'e{}_img_{}'.format(self.epoch, self.step)
                        self.writer.add_image(image_name, image_set, self.step)

        if self.args.rank == 0:
            self.args.logger.info(
                'Epoch [{epoch:d}]\n'
                'L1\t: {l1:.4f}\n'
                'PSNR\t: {psnr:.4f}\n'
                'SSIM\t: {ssim:.4f}\n'
                'IoU\t: {iou:.4f}\n'
                'vgg\t: {vgg:.4f}\n'.format(
                    epoch=self.epoch,
                    l1=val_criteria['l1'].avg,
                    psnr=val_criteria['psnr'].avg,
                    ssim=val_criteria['ssim'].avg,
                    iou=val_criteria['iou'].avg,
                    vgg=val_criteria['vgg'].avg))
            tfb_info = {key: value.avg for key, value in val_criteria.items()}
            self.writer.add_scalars('val/score', tfb_info, self.epoch)

    def test(self):
        self.args.logger.info('testing started')
        self.model.eval()

        with torch.no_grad():
            end = time()
            load_time = 0
            comp_time = 0
            img_count = 0

            for i, data in enumerate(self.val_loader):
                load_time += time() - end
                end = time()
                self.step = i

                # forward pass
                x, fg_mask, gt = self.get_input(data)
                size = x.size(0)
                x = x.cuda(self.args.rank, non_blocking=True)
                fg_mask = fg_mask.cuda(self.args.rank, non_blocking=True)
                gt = gt.cuda(self.args.rank, non_blocking=True)
                img, seg = self.model(x, fg_mask)

                bs = img.size(0)
                for i in range(bs):
                    pred_img = self.normalize(img[i])
                    gt_img = self.normalize(gt[i, :3])
                    save_img(pred_img, '{}/{}_pred.png'.format(self.args.save_dir, img_count))
                    save_img(gt_img, '{}/{}_gt.png'.format(self.args.save_dir, img_count))
                    img_count += 1

                comp_time += time() - end
                end = time()

                # print
                if self.args.rank == 0:
                    if self.step % self.args.disp_interval == 0:
                        self.args.logger.info(
                            'img [{}] load [{load_time:.3f}s] comp [{comp_time:.3f}s]'.format(
                                img_count, load_time=load_time, comp_time=comp_time))
                        comp_time = 0
                        load_time = 0

    def sync(self, loss_dict, mean=True):
        '''Synchronize all tensors given using mean or sum.'''
        for tensor in loss_dict.values():
            dist.all_reduce(tensor)
            if mean:
                tensor.div_(self.args.gpus)

    def save_checkpoint(self):
        save_md_dir = '{}_{}_{}_{}'.format(self.args.model, self.args.mode, self.args.syn_type, self.args.session)
        save_name = os.path.join(self.args.path, 'checkpoint',
                                 save_md_dir + '_{}_{}.pth'.format(self.epoch, self.step))
        self.args.logger.info('Saving checkpoint..')
        torch.save({
            'session': self.args.session,
            'epoch': self.epoch + 1,
            'model': self.model.module.state_dict(),
            'optG': self.optG.state_dict(),
            'optD': self.optD.state_dict()
        }, save_name)
        self.args.logger.info('save model: {}'.format(save_name))

    def load_checkpoint(self):
        load_md_dir = '{}_{}_{}_{}'.format("RefineNet", self.args.mode, self.args.syn_type, self.args.checksession)
        if self.args.load_dir is not None:
            load_name = os.path.join(self.args.load_dir, 'checkpoint',
                                     load_md_dir + '_{}_{}.pth'.format(self.args.checkepoch, self.args.checkpoint))
        else:
            load_name = os.path.join(load_md_dir + '_{}_{}.pth'.format(self.args.checkepoch, self.args.checkpoint))
        self.args.logger.info('Loading checkpoint %s' % load_name)
        ckpt = torch.load(load_name)

        if self.args.lock_coarse:
            model_dict = self.model.module.state_dict()
            new_ckpt = OrderedDict()
            for key, item in ckpt['model'].items():
                if 'coarse' in key:
                    new_ckpt[key] = item
            model_dict.update(new_ckpt)
            self.model.module.load_state_dict(model_dict)
        else:
            self.model.module.load_state_dict(ckpt['model'])

        # transfer opt params to current device
        if not self.args.lock_coarse:
            if not self.args.val:
                self.optimizer.load_state_dict(ckpt['optimizer'])
                self.epoch = ckpt['epoch']
                self.global_step = (self.epoch - 1) * len(self.train_loader)
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda(self.args.rank)
            else:
                assert ckpt['epoch'] - 1 == self.args.checkepoch, [ckpt['epoch'], self.args.checkepoch]
                self.epoch = ckpt['epoch'] - 1
        self.args.logger.info('checkpoint loaded')
def main():
    opt = get_model_config()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(opt)

    # Model setting
    logger.info('Build Model')
    generator = define_G(3, 3, opt.ngf).to(device)
    total_param = sum([p.numel() for p in generator.parameters()])
    logger.info(f'Generator size: {total_param} tensors')

    discriminator = define_D(3 + 3, opt.ndf, opt.disc).to(device)
    total_param = sum([p.numel() for p in discriminator.parameters()])
    logger.info(f'Discriminator size: {total_param} tensors')

    if torch.cuda.device_count() > 1:
        logger.info(f"Let's use {torch.cuda.device_count()} GPUs!")
        generator = DataParallel(generator)
        discriminator = DataParallel(discriminator)

    if opt.mode == 'train':
        dirname = datetime.now().strftime("%m%d%H%M") + f'_{opt.name}'
        log_dir = os.path.join('./experiments', dirname)
        os.makedirs(log_dir, exist_ok=True)
        logger.info(f'LOG DIR: {log_dir}')

        # Dataset setting
        logger.info('Set the dataset')
        image_size: Tuple[int] = (opt.image_h, opt.image_w)
        train_transform, val_transform = get_transforms(image_size, augment_type=opt.augment_type,
                                                         image_norm=opt.image_norm)
        trainset = TrainDataset(image_dir=os.path.join(opt.data_dir, 'train'), transform=train_transform)
        valset = TrainDataset(image_dir=os.path.join(opt.data_dir, 'val'), transform=val_transform)
        train_loader = DataLoader(dataset=trainset, batch_size=opt.batch_size, shuffle=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(dataset=valset, batch_size=opt.batch_size, shuffle=False,
                                num_workers=opt.num_workers)

        # Loss setting
        criterion = {}
        criterion['gan'] = GANLoss(use_lsgan=True).to(device)
        criterion['l1'] = torch.nn.L1Loss().to(device)

        # Optimizer setting
        g_optimizer = get_optimizer(generator.parameters(), opt.optimizer, opt.lr, opt.weight_decay)
        d_optimizer = get_optimizer(discriminator.parameters(), opt.optimizer, opt.lr, opt.weight_decay)
        logger.info(f'Initial Learning rate(G): {g_optimizer.param_groups[0]["lr"]:.6f}')
        logger.info(f'Initial Learning rate(D): {d_optimizer.param_groups[0]["lr"]:.6f}')

        # Scheduler setting
        g_scheduler = get_scheduler(g_optimizer, opt.scheduler, opt)
        d_scheduler = get_scheduler(d_optimizer, opt.scheduler, opt)

        # Tensorboard setting
        writer = SummaryWriter(log_dir=log_dir)

        logger.info('Start to train!')
        train_process(opt, generator, discriminator, criterion, g_optimizer, d_optimizer,
                      g_scheduler, d_scheduler, train_loader=train_loader, val_loader=val_loader,
                      log_dir=log_dir, writer=writer, device=device)

    # TODO: write inference code
    elif opt.mode == 'test':
        logger.info(f'Model loaded from {opt.checkpoint}')
        model.eval()
        logger.info('Start to test!')
        test_status = inference(model=model, test_loader=test_loader, device=device, criterion=criterion)
def train(opt):
    print("train...")

    # Create results directories
    os.makedirs(f'{opt.result_imgs_path}/{opt.dataset_name}-{opt.version}', exist_ok=True)
    os.makedirs(f'{opt.result_models_path}/{opt.dataset_name}-{opt.version}', exist_ok=True)

    # Losses
    mixed_loss = MixedLoss()
    l1_loss = nn.L1Loss()
    l2_loss = nn.MSELoss()
    loss_gan = GANLoss("lsgan")
    val_ssim = SSIM()

    # GPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.manual_seed(777)
    if device == 'cuda':
        torch.cuda.manual_seed_all(777)
        Tensor = torch.cuda.FloatTensor
    else:
        Tensor = torch.FloatTensor

    # Model initialization
    G_AB = Generator().to(device)
    Dis = Discriminator().to(device)

    # Loss initialization
    l1_loss = l1_loss.to(device)
    l2_loss = l2_loss.to(device)
    mixed_loss = mixed_loss.to(device)
    loss_gan = loss_gan.to(device)

    # Load pre-trained models
    if opt.epoch != 0:
        G_AB.load_state_dict(torch.load(
            f'{opt.result_models_path}/{opt.dataset_name}-{opt.version}/G_AB_{opt.epoch:0>4}.pth'))
        Dis.load_state_dict(torch.load(
            f'{opt.result_models_path}/{opt.dataset_name}-{opt.version}/Dis_{opt.epoch:0>4}.pth'))
    # Initialize weights
    else:
        G_AB.apply(weights_init_normal)
        Dis.apply(weights_init_normal)

    # Optimizers
    optimizer_G = torch.optim.Adam(G_AB.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(Dis.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    # Learning rate update schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    # Buffers of previously generated samples
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Image transformations
    transforms_ = [
        transforms.Resize((opt.img_height, opt.img_height), Image.BICUBIC),
        # transforms.RandomCrop((opt.img_height, opt.img_width)),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

    # validation data loading
    val_dataloader = DataLoader(valDataset(f'data/{opt.dataset_name}', transforms_=transforms_, mode='val'),
                                batch_size=6, shuffle=True, num_workers=1)
    # real rainy data loading
    real_dataset = DataLoader(RealDataset(f'data/{opt.dataset_name}', transforms_=transforms_, mode='test'),
                              batch_size=6, shuffle=True, num_workers=1)

    prev_time = time.time()
    for epoch in range(opt.epoch, opt.n_epochs):
        dataloader = DataLoader(
            train_dataset(f'data/{opt.dataset_name}/training',
                          transforms_=transforms_,
                          # rand=1,
                          mode='train'),
            batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)

        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(batch['A'].type(Tensor))
            real_B = Variable(batch['B'].type(Tensor))

            # ------------------
            #  Train Generators
            # ------------------
            optimizer_G.zero_grad()

            gen = G_AB(real_A)
            (a, b, c, d, d_fake_b) = Dis(gen)
            (A, B, C, D, d_real_b) = Dis(real_B)

            loss_p = []
            g_adv = []
            for j, (q, p) in enumerate(zip((a, b, c, d, d_fake_b), (A, B, C, D, d_real_b))):
                p_ = Variable(p, requires_grad=False)
                # perceptual loss
                bat, ch, h, w = q.size()
                loss_p.append(l1_loss(q, p_) * 10)

            g_loss = loss_gan(torch.cat((d_fake_b, d_real_b), 1), True)
            mixed = mixed_loss(gen, real_B)
            loss_perc = torch.mean(torch.stack(loss_p))
            loss_G = 10 * mixed + 10 * g_loss + loss_perc  # + 20 * st_loss
            loss_G.backward()
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator
            # -----------------------
            optimizer_D.zero_grad()

            gen_ = fake_B_buffer.push_and_pop(gen)
            (q, w, e, r, D_fake) = Dis(gen_)
            (z, x, y, u, D_real) = Dis(real_B)

            fake_list = []
            real_list = []
            loss_pp = []
            for j, (g, n) in enumerate(zip((q, w, e, r, D_fake), (z, x, y, u, D_real))):
                n_ = Variable(n, requires_grad=False)
                loss_pp.append(l1_loss(g, n_))

            loss_fake = loss_gan(torch.cat((D_fake, D_real), 1), False)
            loss_real = loss_gan(torch.cat((D_real, D_real), 1), True)

            loss_pa = torch.mean(torch.stack(loss_pp))
            if loss_pa > opt.margin:
                loss_pa = 0
            else:
                loss_pa = opt.margin - loss_pa

            # Total loss
            loss_D = (loss_real + loss_fake) * 0.5 + loss_pa
            loss_D.backward()
            optimizer_D.step()

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + i
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write("\r[Epoch %d/%d] [Batch %d/%d] "
                             "[G_loss: %f, "
                             # "g_loss: %f, "
                             "Adv_loss: %f, "
                             # "mixed: %f, "
                             "GPU Memory Usage: %d, "
                             "lr: %s, "
                             "ETA: %s]" % (
                                 epoch, opt.n_epochs, i, len(dataloader),
                                 loss_G.item(),
                                 # g_loss.item(),
                                 loss_D.item(),
                                 # mixed.item(),
                                 (torch.cuda.memory_allocated() / 1024) / 1024,
                                 lr_scheduler_G.get_lr(),
                                 time_left))

            # If at sample interval save image
            if batches_done % opt.sample_interval == 0:
                imgs = next(iter(val_dataloader))
                A = imgs['A'].type(Tensor)
                generated = G_AB(A)
                B = imgs['B'].type(Tensor)
                ssim = val_ssim(B, generated)

                # PSNR
                mse = nn.MSELoss()
                mm = mse(generated, B)
                pp = 10 * log10(1 / mm.item())

                real_test = next(iter(real_dataset))
                real_rain = real_test['A'].type(Tensor)
                generated_real_snow = G_AB(real_rain)
                img_sample = torch.cat((real_rain.data, generated_real_snow.data), 0)
                save_image(
                    img_sample,
                    f'{opt.result_imgs_path}/{opt.dataset_name}-{opt.version}/{epoch:0>4}_{batches_done:0>4}-ssim_{ssim.item():0.3f} psnr_{pp:0.3f}.png',
                    nrow=6, normalize=True)

        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(G_AB.state_dict(),
                       f'{opt.result_models_path}/{opt.dataset_name}-{opt.version}/G_AB_{epoch:0>4}.pth')
            torch.save(Dis.state_dict(),
                       f'{opt.result_models_path}/{opt.dataset_name}-{opt.version}/Dis_{epoch:0>4}.pth')

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()
def train(self):
    """Train UEGAN."""
    self.fetcher = InputFetcher(self.loaders.ref)
    self.fetcher_val = InputFetcher(self.loaders.val)
    self.train_steps_per_epoch = len(self.loaders.ref)
    self.model_save_step = int(self.args.model_save_epoch * self.train_steps_per_epoch)

    # set nima, psnr, ssim global parameters
    if self.args.is_test_nima:
        self.best_nima_epoch, self.best_nima = 0, 0.0
    if self.args.is_test_psnr_ssim:
        self.best_psnr_epoch, self.best_psnr = 0, 0.0
        self.best_ssim_epoch, self.best_ssim = 0, 0.0

    # set loss functions
    self.criterionPercep = PerceptualLoss()
    self.criterionIdt = MultiscaleRecLoss(scale=3, rec_loss_type=self.args.idt_loss_type, multiscale=True)
    self.criterionGAN = GANLoss(self.args.adv_loss_type, tensor=torch.cuda.FloatTensor)

    # start from scratch or trained models
    if self.args.pretrained_model:
        start_step = int(self.args.pretrained_model * self.train_steps_per_epoch)
        self.load_pretrained_model(self.args.pretrained_model)
    else:
        start_step = 0

    # start training
    print("======================================= start training =======================================")
    self.start_time = time.time()
    total_steps = int(self.args.total_epochs * self.train_steps_per_epoch)
    self.val_start_steps = int(self.args.num_epochs_start_val * self.train_steps_per_epoch)
    self.val_each_steps = int(self.args.val_each_epochs * self.train_steps_per_epoch)

    print("=========== start to iteratively train generator and discriminator ===========")
    pbar = tqdm(total=total_steps, desc='Train epoches', initial=start_step)
    for step in range(start_step, total_steps):
        ########## model train
        self.G.train()
        self.D.train()

        ########## data iter
        input = next(self.fetcher)
        self.real_raw, self.real_exp, self.real_raw_name = input.img_raw, input.img_exp, input.img_name

        ########## forward
        self.fake_exp = self.G(self.real_raw)
        self.fake_exp_store = self.fake_exp_pool.query(self.fake_exp)

        ########## update D
        self.d_optimizer.zero_grad()
        real_exp_preds = self.D(self.real_exp)
        fake_exp_preds = self.D(self.fake_exp_store.detach())
        d_loss = self.criterionGAN(real_exp_preds, fake_exp_preds, None, None, for_discriminator=True)
        if self.args.adv_input:
            input_preds = self.D(self.real_raw)
            d_loss += self.criterionGAN(real_exp_preds, input_preds, None, None, for_discriminator=True)
        d_loss.backward()
        self.d_optimizer.step()
        self.d_loss = d_loss.item()

        ########## update G
        self.g_optimizer.zero_grad()
        real_exp_preds = self.D(self.real_exp)
        fake_exp_preds = self.D(self.fake_exp)
        g_adv_loss = self.args.lambda_adv * self.criterionGAN(real_exp_preds, fake_exp_preds,
                                                              None, None, for_discriminator=False)
        self.g_adv_loss = g_adv_loss.item()
        g_loss = g_adv_loss

        g_percep_loss = self.args.lambda_percep * self.criterionPercep(
            (self.fake_exp + 1.) / 2., (self.real_raw + 1.) / 2.)
        self.g_percep_loss = g_percep_loss.item()
        g_loss += g_percep_loss

        self.real_exp_idt = self.G(self.real_exp)
        g_idt_loss = self.args.lambda_idt * self.criterionIdt(self.real_exp_idt, self.real_exp)
        self.g_idt_loss = g_idt_loss.item()
        g_loss += g_idt_loss

        g_loss.backward()
        self.g_optimizer.step()
        self.g_loss = g_loss.item()

        ### print info and save models
        self.print_info(step, total_steps, pbar)

        ### logging using tensorboard
        self.logging(step)

        ### validation
        self.model_validation(step)

        ### learning rate update
        if step % self.train_steps_per_epoch == 0:
            current_epoch = step // self.train_steps_per_epoch
            self.lr_scheduler_g.step(epoch=current_epoch)
            self.lr_scheduler_d.step(epoch=current_epoch)
            for param_group in self.g_optimizer.param_groups:
                pbar.write("====== Epoch: {:>3d}/{}, Learning rate(lr) of Encoder(E) and Generator(G): [{}], "
                           .format(((step + 1) // self.train_steps_per_epoch), self.args.total_epochs,
                                   param_group['lr']), end='')
            for param_group in self.d_optimizer.param_groups:
                pbar.write("Learning rate (lr) of Discriminator(D): [{}] ======".format(param_group['lr']))

        pbar.update(1)
        pbar.set_description(f"Train epoch %.2f" % ((step + 1.0) / self.train_steps_per_epoch))

    self.val_best_results()

    pbar.write("=========== Complete training ===========")
    pbar.close()
# imports assumed by this snippet (Generator, Discriminator and GANLoss come from the project itself)
import torch
import torch.optim as optim
import torch.utils.data as data
from torchvision import datasets, transforms

# load dataset and data loader
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.MNIST('.', transform=transform, download=True)
dataloader = data.DataLoader(dataset, batch_size=4)

# model
g = Generator()
d = Discriminator()

# losses
gan_loss = GANLoss()

# use
is_cuda = torch.cuda.is_available()
if is_cuda:
    g = g.cuda()
    d = d.cuda()

# optimizer
optim_G = optim.Adam(g.parameters())
optim_D = optim.Adam(d.parameters())

# train
num_epoch = 10  # not defined in the original snippet; assumed value
for epoch in range(num_epoch):
    total_batch = len(dataloader)
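    # The per-batch body is truncated in the original snippet. What follows is a minimal,
    # hypothetical sketch of a standard GAN step with this setup: it assumes GANLoss is
    # called as gan_loss(prediction, target_is_real) (as in the other snippets here) and
    # that Generator takes a 100-dim noise vector. Illustrative only, not the author's code.
    for batch_idx, (real, _) in enumerate(dataloader):
        if is_cuda:
            real = real.cuda()
        noise = torch.randn(real.size(0), 100, device=real.device)  # assumed latent size
        fake = g(noise)

        # discriminator update: real images -> True, detached fakes -> False
        optim_D.zero_grad()
        loss_d = gan_loss(d(real), True) + gan_loss(d(fake.detach()), False)
        loss_d.backward()
        optim_D.step()

        # generator update: the discriminator should label the fakes as True
        optim_G.zero_grad()
        loss_g = gan_loss(d(fake), True)
        loss_g.backward()
        optim_G.step()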
def train(args):
    # check if results path exists, if not create the folder
    check_folder(args.results_path)

    # generator model
    generator = HourglassNet(high_res=args.high_resolution)
    generator.to(device)

    # discriminator model
    discriminator = Discriminator(input_nc=1)
    discriminator.to(device)

    # optimizer
    optimizer_g = torch.optim.Adam(generator.parameters())
    optimizer_d = torch.optim.Adam(discriminator.parameters())

    # training parameters
    feature_weight = 0.5
    skip_count = 0
    use_gan = args.use_gan
    print_frequency = 5

    # dataloader
    illum_dataset = IlluminationDataset()
    illum_dataloader = DataLoader(illum_dataset, batch_size=args.batch_size)

    # gan loss based on lsgan that uses squared error
    gan_loss = GANLoss(gan_mode='lsgan')

    # training
    for epoch in range(1, args.epochs + 1):
        for data_idx, data in enumerate(illum_dataloader):
            source_img, source_light, target_img, target_light = data
            # Tensor.to(device) is not in-place, so the results must be reassigned
            source_img = source_img.to(device)
            source_light = source_light.to(device)
            target_img = target_img.to(device)
            target_light = target_light.to(device)

            optimizer_g.zero_grad()

            # if skip connections are required for training, else skip the
            # connections based on the training scheme for low-res/high-res images
            if args.use_skip:
                skip_count = 0
            else:
                skip_count = 5 if args.high_resolution else 4

            output = generator(source_img, target_light, skip_count, target_img)
            source_face_feats, source_light_pred, target_face_feats, source_relit_pred = output

            img_loss = image_and_light_loss(source_relit_pred, target_img, source_light_pred, target_light)
            feat_loss = feature_loss(source_face_feats, target_face_feats)

            # if gan loss is used
            if use_gan:
                g_loss = gan_loss(discriminator(source_relit_pred), target_is_real=True)
            else:
                g_loss = torch.Tensor([0])

            total_g_loss = img_loss + g_loss + (feature_weight * feat_loss)
            total_g_loss.backward()
            optimizer_g.step()

            # training the discriminator
            if use_gan:
                optimizer_d.zero_grad()
                pred_real = discriminator(target_img)
                pred_fake = discriminator(source_relit_pred.detach())
                loss_real = gan_loss(pred_real, target_is_real=True)
                loss_fake = gan_loss(pred_fake, target_is_real=False)
                d_loss = (loss_real + loss_fake) * 0.5
                d_loss.backward()
                optimizer_d.step()
            else:
                loss_real = torch.Tensor([0])
                loss_fake = torch.Tensor([0])

            if data_idx % print_frequency == 0:
                print("Epoch: [{}]/[{}], Iteration: [{}]/[{}], image loss: {}, feature loss: {}, "
                      "gen fake loss: {}, dis real loss: {}, dis fake loss: {}".format(
                          epoch, args.epochs + 1, data_idx + 1, len(illum_dataloader),
                          img_loss.item(), feat_loss.item(), g_loss.item(),
                          loss_real.item(), loss_fake.item()))

        # saving model
        checkpoint_path = os.path.join(args.results_path, 'checkpoint_epoch_{}.pth'.format(epoch))
        checkpoint = {
            'generator': generator.state_dict(),
            'discriminator': discriminator.state_dict(),
            'optimizer_g': optimizer_g.state_dict(),
            'optimizer_d': optimizer_d.state_dict()
        }
        torch.save(checkpoint, checkpoint_path)