Example #1
    def __init__(self, args):
        ##The Network
        self.g_AB = Generator(args.n_mel_channels, args.ngf,
                              args.n_residual_layers).cuda(
                              )  #initialise generator with n mel channels

        self.g_BA = Generator(args.n_mel_channels, args.ngf,
                              args.n_residual_layers).cuda(
                              )  #initialise generator with n mel channels
        self.Da = Discriminator(
            args.num_D, args.ndf, args.n_layers_D,
            args.downsamp_factor).cuda()  #initialize discriminator
        self.Db = Discriminator(
            args.num_D, args.ndf, args.n_layers_D,
            args.downsamp_factor).cuda()  #initialize discriminator

        self.fft = Audio2Mel(n_mel_channels=args.n_mel_channels).cuda()

        #The Losses
        self.MSE = nn.MSELoss()
        self.L1 = nn.L1Loss()

        # Optimizers
        #####################################################
        self.g_optimizer = torch.optim.Adam(itertools.chain(
            self.g_AB.parameters(), self.g_BA.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(itertools.chain(
            self.Da.parameters(), self.Db.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))

        self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.g_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
        self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.d_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

        #Load potential checkpoint
        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        try:
            ckpt = utils.load_checkpoint('%s/latest.ckpt' %
                                         (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.Da.load_state_dict(ckpt['Da'])
            self.Db.load_state_dict(ckpt['Db'])
            self.g_AB.load_state_dict(ckpt['Gab'])
            self.g_BA.load_state_dict(ckpt['Gba'])
            self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
            self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
        except Exception:
            print(' [*] No checkpoint!')
            self.start_epoch = 0
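
All of these examples pass `utils.LambdaLR(...).step` as the `lr_lambda`, but the helper itself is never shown. A minimal sketch of what it typically looks like (constant learning rate, then linear decay to zero after `decay_epoch`) — this is an assumption about `utils`, not the verbatim implementation:

class LambdaLR:
    """Multiplicative LR factor: 1.0 until decay_start_epoch, then linear decay to 0."""

    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "decay must start before training ends"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        # The returned value multiplies the base LR set in the optimizer.
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (
            self.n_epochs - self.decay_start_epoch)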
Example #2
    def __init__(self, args):

        self.Gab = define_Gen(input_nc=3,
                              output_nc=3,
                              netG=args.gen_net,
                              gpu_ids=args.gpu_ids)
        self.Gba = define_Gen(input_nc=3,
                              output_nc=3,
                              netG=args.gen_net,
                              gpu_ids=args.gpu_ids)
        self.Da = define_Dis(input_nc=3, gpu_ids=args.gpu_ids)
        self.Db = define_Dis(input_nc=3, gpu_ids=args.gpu_ids)

        self.MSE = nn.MSELoss()
        self.L1 = nn.L1Loss()

        self.g_optimizer = torch.optim.Adam(itertools.chain(
            self.Gab.parameters(), self.Gba.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(itertools.chain(
            self.Da.parameters(), self.Db.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))

        self.g_lr_scheduler = lr_scheduler.LambdaLR(self.g_optimizer,
                                                    lr_lambda=utils.LambdaLR(
                                                        args.epochs, 0,
                                                        args.decay_epoch).step)
        self.d_lr_scheduler = lr_scheduler.LambdaLR(self.d_optimizer,
                                                    lr_lambda=utils.LambdaLR(
                                                        args.epochs, 0,
                                                        args.decay_epoch).step)

        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        try:
            ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.Da.load_state_dict(ckpt['Da'])
            self.Db.load_state_dict(ckpt['Db'])
            self.Gab.load_state_dict(ckpt['Gab'])
            self.Gba.load_state_dict(ckpt['Gba'])
            self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
            self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
        except Exception:
            print(' [*] No checkpoint!')
            self.start_epoch = 0
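
`utils.load_checkpoint` (and its saving counterpart) is also external to these snippets. A plausible minimal pair, assuming the checkpoint is a plain dict serialized with `torch.save`, with the key names taken from the examples:

import torch


def save_checkpoint(state, save_path):
    """Serialize a dict of state_dicts plus the epoch counter."""
    torch.save(state, save_path)


def load_checkpoint(ckpt_path, map_location=None):
    """Load a checkpoint dict saved by save_checkpoint; raises if the file is missing."""
    ckpt = torch.load(ckpt_path, map_location=map_location)
    print(' [*] Loaded checkpoint from %s (epoch %d)' % (ckpt_path, ckpt['epoch']))
    return ckpt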
Example #3
    def __init__(self, args):

        # Define the network
        #####################################################
        self.Gab = define_Gen(input_nc=3,
                              output_nc=3,
                              ngf=args.ngf,
                              netG=args.gen_net,
                              norm=args.norm,
                              use_dropout=not args.no_dropout,
                              gpu_ids=args.gpu_ids)
        self.Gba = define_Gen(input_nc=3,
                              output_nc=3,
                              ngf=args.ngf,
                              netG=args.gen_net,
                              norm=args.norm,
                              use_dropout=not args.no_dropout,
                              gpu_ids=args.gpu_ids)
        self.Da = define_Dis(input_nc=3,
                             ndf=args.ndf,
                             netD=args.dis_net,
                             n_layers_D=3,
                             norm=args.norm,
                             gpu_ids=args.gpu_ids)
        self.Db = define_Dis(input_nc=3,
                             ndf=args.ndf,
                             netD=args.dis_net,
                             n_layers_D=3,
                             norm=args.norm,
                             gpu_ids=args.gpu_ids)

        utils.print_networks([self.Gab, self.Gba, self.Da, self.Db],
                             ['Gab', 'Gba', 'Da', 'Db'])

        # Define loss criteria

        self.MSE = nn.MSELoss()
        self.L1 = nn.L1Loss()

        # Optimizers
        #####################################################
        self.g_optimizer = torch.optim.Adam(itertools.chain(
            self.Gab.parameters(), self.Gba.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(itertools.chain(
            self.Da.parameters(), self.Db.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))

        self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.g_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
        self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.d_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

        # Try loading checkpoint
        #####################################################
        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        try:
            ckpt = utils.load_checkpoint('%s/latest.ckpt' %
                                         (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.Da.load_state_dict(ckpt['Da'])
            self.Db.load_state_dict(ckpt['Db'])
            self.Gab.load_state_dict(ckpt['Gab'])
            self.Gba.load_state_dict(ckpt['Gba'])
            self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
            self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
        except Exception:
            print(' [*] No checkpoint!')
            self.start_epoch = 0
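
The constructors only define `self.MSE` and `self.L1`; the actual objectives live in the training loop. A sketch of how the MSE criterion becomes a least-squares GAN loss (the helper names and the 0.5 weighting are illustrative, not from the source):

import torch


def lsgan_generator_loss(mse, d_out_fake):
    # The generator tries to make the discriminator output 1 ("real") on fakes.
    return mse(d_out_fake, torch.ones_like(d_out_fake))


def lsgan_discriminator_loss(mse, d_out_real, d_out_fake):
    # The discriminator pushes real outputs to 1 and fake outputs to 0.
    real_loss = mse(d_out_real, torch.ones_like(d_out_real))
    fake_loss = mse(d_out_fake, torch.zeros_like(d_out_fake))
    return 0.5 * (real_loss + fake_loss)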
Example #4
	def __init__(self, args):
		## The networks
		self.g_AB = Generator(conv_dim=args.ngf, n_res_blocks=2).cuda()  # initialise generator A -> B

		self.g_BA = Generator(conv_dim=args.ngf, n_res_blocks=2).cuda()  # initialise generator B -> A
		self.Da = Discriminator(conv_dim=args.ndf).cuda()  # initialise discriminator for domain A
		self.Db = Discriminator(conv_dim=args.ndf).cuda()  # initialise discriminator for domain B

		# Loss type: Wasserstein or basic
		self.loss_type = args.loss

		# The losses
		self.MSE = nn.MSELoss()
		self.L1 = nn.L1Loss()

		# Optimizers
		#####################################################
		self.g_optimizer = torch.optim.Adam(itertools.chain(self.g_AB.parameters(),self.g_BA.parameters()), lr=args.lr, betas=(0.5, 0.99))
		self.d_optimizer = torch.optim.Adam(itertools.chain(self.Da.parameters(),self.Db.parameters()), lr=args.lr, betas=(0.5, 0.99))
		

		self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.g_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
		self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.d_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

		
		#Load potential checkpoint
		if not os.path.isdir(args.checkpoint_dir):
			os.makedirs(args.checkpoint_dir)

		try:
			ckpt = utils.load_checkpoint('./checkpoints/w_gan_3.ckpt')
			self.start_epoch = ckpt['epoch']
			self.Da.load_state_dict(ckpt['Da'])
			self.Db.load_state_dict(ckpt['Db'])
			self.g_AB.load_state_dict(ckpt['Gab'])
			self.g_BA.load_state_dict(ckpt['Gba'])
			self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
			self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
		except Exception:
			print(' [*] No checkpoint!')
			self.start_epoch = 0



		# Criterion for WGAN
		self.criterionGAN = WassersteinGANLoss()
		self.wgan_n_critic = 10
		self.wgan_clamp_lower = -0.01  # clamp critic weights to [-0.01, 0.01]
		self.wgan_clamp_upper = 0.01

		# Load audio data
		logf0s_normalization = np.load("./cache_check/logf0s_normalization.npz")
		self.log_f0s_mean_A = logf0s_normalization['mean_A']
		self.log_f0s_std_A = logf0s_normalization['std_A']
		self.log_f0s_mean_B = logf0s_normalization['mean_B']
		self.log_f0s_std_B = logf0s_normalization['std_B']

		mcep_normalization = np.load("./cache_check/mcep_normalization.npz")
		self.coded_sps_A_mean = mcep_normalization['mean_A']
		self.coded_sps_A_std = mcep_normalization['std_A']
		self.coded_sps_B_mean = mcep_normalization['mean_B']
		self.coded_sps_B_std = mcep_normalization['std_B']
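
A sketch of how the Wasserstein settings above (`wgan_n_critic`, `wgan_clamp_lower/upper`) are normally used: the critic takes several updates per generator step, with weight clamping after each one to enforce the Lipschitz constraint. The loop is illustrative, assuming the critic returns a scalar score per sample:

def train_critic(critic, d_optimizer, real, fake,
                 n_critic=10, clamp_lower=-0.01, clamp_upper=0.01):
    """Weight-clipped WGAN critic updates (Arjovsky et al., 2017)."""
    for _ in range(n_critic):
        d_optimizer.zero_grad()
        # The critic maximizes E[D(real)] - E[D(fake)]; minimize the negation.
        loss = critic(fake.detach()).mean() - critic(real).mean()
        loss.backward()
        d_optimizer.step()
        # Enforce the Lipschitz constraint by clamping weights.
        for p in critic.parameters():
            p.data.clamp_(clamp_lower, clamp_upper)
    return loss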
Example #5
    def __init__(self, args):

        if args.dataset == 'voc2012':
            self.n_channels = 21
        elif args.dataset == 'cityscapes':
            self.n_channels = 20
        elif args.dataset == 'acdc':
            self.n_channels = 4

        # Define the network
        #####################################################
        # for segmentation to image
        self.Gis = define_Gen(input_nc=self.n_channels,
                              output_nc=3,
                              ngf=args.ngf,
                              netG='deeplab',
                              norm=args.norm,
                              use_dropout=not args.no_dropout,
                              gpu_ids=args.gpu_ids)
        # for image to segmentation
        self.Gsi = define_Gen(input_nc=3,
                              output_nc=self.n_channels,
                              ngf=args.ngf,
                              netG='deeplab',
                              norm=args.norm,
                              use_dropout=not args.no_dropout,
                              gpu_ids=args.gpu_ids)
        self.Di = define_Dis(input_nc=3,
                             ndf=args.ndf,
                             netD='pixel',
                             n_layers_D=3,
                             norm=args.norm,
                             gpu_ids=args.gpu_ids)
        self.Ds = define_Dis(
            input_nc=self.n_channels,
            ndf=args.ndf,
            netD='pixel',
            n_layers_D=3,
            norm=args.norm,
            gpu_ids=args.gpu_ids)  # for voc 2012, there are 21 classes

        self.old_Gis = define_Gen(input_nc=self.n_channels,
                                  output_nc=3,
                                  ngf=args.ngf,
                                  netG='resnet_9blocks',
                                  norm=args.norm,
                                  use_dropout=not args.no_dropout,
                                  gpu_ids=args.gpu_ids)
        self.old_Gsi = define_Gen(input_nc=3,
                                  output_nc=self.n_channels,
                                  ngf=args.ngf,
                                  netG='resnet_9blocks_softmax',
                                  norm=args.norm,
                                  use_dropout=not args.no_dropout,
                                  gpu_ids=args.gpu_ids)
        self.old_Di = define_Dis(input_nc=3,
                                 ndf=args.ndf,
                                 netD='pixel',
                                 n_layers_D=3,
                                 norm=args.norm,
                                 gpu_ids=args.gpu_ids)

        ### To put the pretrained weights in Gis and Gsi
        # if args.dataset != 'acdc':
        #     saved_state_dict = torch.load(pretrained_loc)
        #     new_params_Gsi = self.Gsi.state_dict().copy()
        #     # new_params_Gis = self.Gis.state_dict().copy()
        #     for name, param in new_params_Gsi.items():
        #         # print(name)
        #         if name in saved_state_dict and param.size() == saved_state_dict[name].size():
        #             new_params_Gsi[name].copy_(saved_state_dict[name])
        #             # print('copy {}'.format(name))
        #     self.Gsi.load_state_dict(new_params_Gsi)
        # for name, param in new_params_Gis.items():
        #     # print(name)
        #     if name in saved_state_dict and param.size() == saved_state_dict[name].size():
        #         new_params_Gis[name].copy_(saved_state_dict[name])
        #         # print('copy {}'.format(name))
        # # self.Gis.load_state_dict(new_params_Gis)

        ### Load pretrained weights into old_Gis/old_Gsi for the voc2012 case
        if args.dataset == 'voc2012':
            try:
                ckpt_for_Arnab_loss = utils.load_checkpoint(
                    './ckpt_for_Arnab_loss.ckpt')
                self.old_Gis.load_state_dict(ckpt_for_Arnab_loss['Gis'])
                self.old_Gsi.load_state_dict(ckpt_for_Arnab_loss['Gsi'])
            except Exception:
                print(
                    '**There is an error in loading the ckpt_for_Arnab_loss**')

        utils.print_networks([self.Gsi], ['Gsi'])

        utils.print_networks([self.Gis, self.Gsi, self.Di, self.Ds],
                             ['Gis', 'Gsi', 'Di', 'Ds'])

        self.args = args

        ### interpolation
        self.interp = nn.Upsample((args.crop_height, args.crop_width),
                                  mode='bilinear',
                                  align_corners=True)

        self.MSE = nn.MSELoss()
        self.L1 = nn.L1Loss()
        self.CE = nn.CrossEntropyLoss()
        self.activation_softmax = nn.Softmax2d()
        self.activation_tanh = nn.Tanh()
        self.activation_sigmoid = nn.Sigmoid()

        ### Tensorboard writer
        self.writer_semisuper = SummaryWriter(tensorboard_loc + '_semisuper')
        self.running_metrics_val = utils.runningScore(self.n_channels,
                                                      args.dataset)

        ### For adding gaussian noise
        self.gauss_noise = utils.GaussianNoise(sigma=0.2)

        # Optimizers
        #####################################################
        self.g_optimizer = torch.optim.Adam(itertools.chain(
            self.Gis.parameters(), self.Gsi.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(itertools.chain(
            self.Di.parameters(), self.Ds.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))

        self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.g_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
        self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.d_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

        # Try loading checkpoint
        #####################################################
        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        try:
            ckpt = utils.load_checkpoint('%s/latest_semisuper_cycleGAN.ckpt' %
                                         (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.Di.load_state_dict(ckpt['Di'])
            self.Ds.load_state_dict(ckpt['Ds'])
            self.Gis.load_state_dict(ckpt['Gis'])
            self.Gsi.load_state_dict(ckpt['Gsi'])
            self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
            self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
            self.best_iou = ckpt['best_iou']
        except Exception:
            print(' [*] No checkpoint!')
            self.start_epoch = 0
            self.best_iou = -100
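
This semi-supervised trainer keeps a `CrossEntropyLoss`, a bilinear `interp` layer and softmax/tanh activations; a sketch of how the supervised branch typically combines them (a hypothetical helper, not from the source):

def supervised_seg_loss(Gsi, interp, ce, image, label):
    """Cross-entropy on upsampled segmentation logits.

    label: LongTensor of class indices with shape (N, H, W).
    """
    logits = interp(Gsi(image))  # (N, n_channels, crop_height, crop_width)
    return ce(logits, label)     # CrossEntropyLoss applies log-softmax internally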
Example #6
criterion_identity = torch.nn.L1Loss()

# optimizer
optimizer_G = torch.optim.Adam(itertools.chain(G.parameters(), F.parameters()),
                               lr=lr,
                               betas=(0.5, 0.999))
optimizer_Dx = torch.optim.Adam(Dx.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_Dy = torch.optim.Adam(Dy.parameters(), lr=lr, betas=(0.5, 0.999))

# tensor wrapper
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

# change lr according to the epoch
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G,
                                                   lr_lambda=utils.LambdaLR(
                                                       epochs, 0,
                                                       decay_epoch).step)
lr_scheduler_Dx = torch.optim.lr_scheduler.LambdaLR(optimizer_Dx,
                                                    lr_lambda=utils.LambdaLR(
                                                        epochs, 0,
                                                        decay_epoch).step)
lr_scheduler_Dy = torch.optim.lr_scheduler.LambdaLR(optimizer_Dy,
                                                    lr_lambda=utils.LambdaLR(
                                                        epochs, 0,
                                                        decay_epoch).step)

# replay buffer
fake_X_buffer = utils.ReplayBuffer()
fake_Y_buffer = utils.ReplayBuffer()

input_X = Tensor(batch_size, input_nc, image_size, image_size)
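
`utils.ReplayBuffer` is the image pool from the CycleGAN paper: discriminators are trained on a mix of current and past generator outputs. A sketch of the standard implementation these snippets appear to rely on:

import random

import torch


class ReplayBuffer:
    """Keep up to max_size past fakes; with probability 0.5 swap a stored one in."""

    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, batch):
        out = []
        for element in batch.data:
            element = element.unsqueeze(0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                out.append(element)
            elif random.random() > 0.5:
                # Return a stored fake and replace it with the current one.
                i = random.randint(0, self.max_size - 1)
                out.append(self.data[i].clone())
                self.data[i] = element
            else:
                out.append(element)
        return torch.cat(out)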
Example #7
    def __init__(self, args):

        # Set up both gens and discs
        self.Gab = define_Gen(input_nc=3,
                              output_nc=3,
                              ngf=args.ngf,
                              netG=args.gen_net,
                              norm=args.norm,
                              use_dropout=args.use_dropout,
                              gpu_ids=args.gpu_ids,
                              self_attn=args.self_attn,
                              spectral=args.spectral)
        self.Gba = define_Gen(input_nc=3,
                              output_nc=3,
                              ngf=args.ngf,
                              netG=args.gen_net,
                              norm=args.norm,
                              use_dropout=args.use_dropout,
                              gpu_ids=args.gpu_ids,
                              self_attn=args.self_attn,
                              spectral=args.spectral)

        self.Da = define_Dis(input_nc=3,
                             ndf=args.ndf,
                             netD=args.dis_net,
                             n_layers_D=3,
                             norm=args.norm,
                             gpu_ids=args.gpu_ids,
                             spectral=args.spectral,
                             self_attn=args.self_attn)
        self.Db = define_Dis(input_nc=3,
                             ndf=args.ndf,
                             netD=args.dis_net,
                             n_layers_D=3,
                             norm=args.norm,
                             gpu_ids=args.gpu_ids,
                             spectral=args.spectral,
                             self_attn=args.self_attn)

        utils.print_networks([self.Gab, self.Gba, self.Da, self.Db],
                             ['Gab', 'Gba', 'Da', 'Db'])

        # Loss functions
        self.MSE = nn.MSELoss()
        self.L1 = nn.L1Loss()
        self.ssim = kornia.losses.SSIM(11, reduction='mean')

        # Optimizers
        self.g_optimizer = torch.optim.Adam(itertools.chain(
            self.Gab.parameters(), self.Gba.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(itertools.chain(
            self.Da.parameters(), self.Db.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))

        self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.g_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
        self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.d_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

        # Checkpoints
        if not os.path.isdir(args.checkpoint_path):
            os.makedirs(args.checkpoint_path)

        try:
            ckpt = utils.load_checkpoint('%s/latest.ckpt' %
                                         (args.checkpoint_path))
            self.start_epoch = ckpt['epoch']
            self.Da.load_state_dict(ckpt['Da'])
            self.Db.load_state_dict(ckpt['Db'])
            self.Gab.load_state_dict(ckpt['Gab'])
            self.Gba.load_state_dict(ckpt['Gba'])
            self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
            self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
        except Exception:
            print(' [*] No checkpoint!')
            self.start_epoch = 0
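
Example #7 is the only one with a structural term: in kornia releases of that era, `kornia.losses.SSIM(11, reduction='mean')` behaves as a dissimilarity loss, so it can be blended with L1 roughly like this (the weighting is illustrative):

def reconstruction_loss(l1, ssim, fake, real, alpha=0.5):
    # Blend pixel-wise L1 with structural dissimilarity; both are minimized.
    return alpha * l1(fake, real) + (1.0 - alpha) * ssim(fake, real)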
Example #8
    def __init__(self, args):

        if args.dataset == 'voc2012':
            self.n_channels = 21
        elif args.dataset == 'cityscapes':
            self.n_channels = 20
        elif args.dataset == 'acdc':
            self.n_channels = 4

        # Define the network
        #####################################################
        # for segmentation to image
        self.Gis = define_Gen(input_nc=self.n_channels,
                              output_nc=3,
                              ngf=args.ngf,
                              netG='resnet_9blocks',
                              norm=args.norm,
                              use_dropout=not args.no_dropout,
                              gpu_ids=args.gpu_ids)
        # for image to segmentation
        self.Gsi = define_Gen(input_nc=3,
                              output_nc=self.n_channels,
                              ngf=args.ngf,
                              netG='resnet_9blocks_softmax',
                              norm=args.norm,
                              use_dropout=not args.no_dropout,
                              gpu_ids=args.gpu_ids)
        self.Di = define_Dis(input_nc=3,
                             ndf=args.ndf,
                             netD='pixel',
                             n_layers_D=3,
                             norm=args.norm,
                             gpu_ids=args.gpu_ids)
        self.Ds = define_Dis(
            input_nc=1,
            ndf=args.ndf,
            netD='pixel',
            n_layers_D=3,
            norm=args.norm,
            gpu_ids=args.gpu_ids)  # for voc 2012, there are 21 classes

        utils.print_networks([self.Gis, self.Gsi, self.Di, self.Ds],
                             ['Gis', 'Gsi', 'Di', 'Ds'])

        self.args = args

        self.MSE = nn.MSELoss()
        self.L1 = nn.L1Loss()
        self.CE = nn.CrossEntropyLoss()
        self.activation_softmax = nn.Softmax2d()

        ### Tensorboard writer
        self.writer_semisuper = SummaryWriter(tensorboard_loc + '_semisuper')
        self.running_metrics_val = utils.runningScore(self.n_channels,
                                                      args.dataset)

        ### For adding gaussian noise
        self.gauss_noise = utils.GaussianNoise(sigma=0.2)

        # Optimizers
        #####################################################
        self.g_optimizer = torch.optim.Adam(itertools.chain(
            self.Gis.parameters(), self.Gsi.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(itertools.chain(
            self.Di.parameters(), self.Ds.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))

        self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.g_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
        self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.d_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

        # Try loading checkpoint
        #####################################################
        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        try:
            ckpt = utils.load_checkpoint('%s/latest_semisuper_cycleGAN.ckpt' %
                                         (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.Di.load_state_dict(ckpt['Di'])
            self.Ds.load_state_dict(ckpt['Ds'])
            self.Gis.load_state_dict(ckpt['Gis'])
            self.Gsi.load_state_dict(ckpt['Gsi'])
            self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
            self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
            self.best_iou = ckpt['best_iou']
        except Exception:
            print(' [*] No checkpoint!')
            self.start_epoch = 0
            self.best_iou = -100
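
The `best_iou` field restored above implies a save path keyed on validation mIoU. A plausible end-of-epoch counterpart; the `get_scores()` call and the `'Mean IoU'` key follow common `runningScore` implementations and are assumptions:

def save_latest(t, epoch, args):
    """Hypothetical end-of-epoch hook: persist state and track the best mIoU."""
    score, _ = t.running_metrics_val.get_scores()    # assumed runningScore API
    t.best_iou = max(t.best_iou, score['Mean IoU'])  # key name is an assumption
    utils.save_checkpoint({
        'epoch': epoch + 1,
        'Di': t.Di.state_dict(),
        'Ds': t.Ds.state_dict(),
        'Gis': t.Gis.state_dict(),
        'Gsi': t.Gsi.state_dict(),
        'd_optimizer': t.d_optimizer.state_dict(),
        'g_optimizer': t.g_optimizer.state_dict(),
        'best_iou': t.best_iou,
    }, '%s/latest_semisuper_cycleGAN.ckpt' % args.checkpoint_dir)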
Example #9
    def __init__(self, args):

        # Generators and Discriminators
        self.G_AtoB = define_Gen(input_nc=3,
                                 output_nc=3,
                                 ngf=args.ngf,
                                 norm=args.norm,
                                 use_dropout=not args.no_dropout,
                                 gpu_ids=args.gpu_ids)
        self.G_BtoA = define_Gen(input_nc=3,
                                 output_nc=3,
                                 ngf=args.ngf,
                                 norm=args.norm,
                                 use_dropout=not args.no_dropout,
                                 gpu_ids=args.gpu_ids)
        self.D_A = define_Dis(input_nc=3,
                              ndf=args.ndf,
                              norm=args.norm,
                              gpu_ids=args.gpu_ids)
        self.D_B = define_Dis(input_nc=3,
                              ndf=args.ndf,
                              norm=args.norm,
                              gpu_ids=args.gpu_ids)

        utils.print_networks([self.G_AtoB, self.G_BtoA, self.D_A, self.D_B],
                             ['G_AtoB', 'G_BtoA', 'D_A', 'D_B'])

        # MSE loss and L1 loss
        self.MSE = nn.MSELoss()
        self.L1 = nn.L1Loss()

        # Optimizers and lr_scheduler
        self.g_optimizer = torch.optim.Adam(itertools.chain(
            self.G_AtoB.parameters(), self.G_BtoA.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(itertools.chain(
            self.D_A.parameters(), self.D_B.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))

        self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.g_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
        self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.d_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

        # Check if there is a checkpoint
        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        try:
            ckpt = utils.load_checkpoint('%s/latest.ckpt' %
                                         (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.D_A.load_state_dict(ckpt['D_A'])
            self.D_B.load_state_dict(ckpt['D_B'])
            self.G_AtoB.load_state_dict(ckpt['G_AtoB'])
            self.G_BtoA.load_state_dict(ckpt['G_BtoA'])
            self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
            self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
        except Exception:
            print(' [*] No checkpoint! Training from the beginning!')
            self.start_epoch = 0
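
None of the examples show where the schedulers advance; the usual pattern is one `step()` per epoch, after the optimizer updates (a sketch; `trainer` and `dataloader` are placeholders):

for epoch in range(trainer.start_epoch, args.epochs):
    for batch in dataloader:
        ...  # generator and discriminator updates for this batch
    trainer.g_lr_scheduler.step()  # decay both learning rates once per epoch
    trainer.d_lr_scheduler.step()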
Example #10
    def __init__(self, args):
        # Define the network
        #####################################################
        '''
        Define the network:
        Two generators: Gab, Gba
        Two discriminators: Da, Db
        '''
        self.Gab = define_Gen(input_nc=3,
                              output_nc=3,
                              ngf=args.ngf,
                              netG=args.gen_net,
                              norm=args.norm,
                              use_dropout=not args.no_dropout,
                              gpu_ids=args.gpu_ids)
        self.Gba = define_Gen(input_nc=3,
                              output_nc=3,
                              ngf=args.ngf,
                              netG=args.gen_net,
                              norm=args.norm,
                              use_dropout=not args.no_dropout,
                              gpu_ids=args.gpu_ids)
        self.Da = define_Dis(input_nc=3,
                             ndf=args.ndf,
                             netD=args.dis_net,
                             n_layers_D=3,
                             norm=args.norm,
                             gpu_ids=args.gpu_ids)
        self.Db = define_Dis(input_nc=3,
                             ndf=args.ndf,
                             netD=args.dis_net,
                             n_layers_D=3,
                             norm=args.norm,
                             gpu_ids=args.gpu_ids)

        utils.print_networks([self.Gab, self.Gba, self.Da, self.Db],
                             ['Gab', 'Gba', 'Da', 'Db'])

        # Define loss criteria
        self.identity_criteron = nn.L1Loss()
        self.adversarial_criteron = nn.MSELoss()
        self.cycle_consistency_criteron = nn.L1Loss()

        # Define optimizers
        self.g_optimizer = torch.optim.Adam(itertools.chain(
            self.Gab.parameters(), self.Gba.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(itertools.chain(
            self.Da.parameters(), self.Db.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))

        # Define learning rate schedulers
        self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.g_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
        self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.d_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

        # Try loading checkpoint
        #####################################################
        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        try:
            ckpt = utils.load_checkpoint('%s/latest.ckpt' %
                                         (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.Da.load_state_dict(ckpt['Da'])
            self.Db.load_state_dict(ckpt['Db'])
            self.Gab.load_state_dict(ckpt['Gab'])
            self.Gba.load_state_dict(ckpt['Gba'])
            self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
            self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
        except Exception:
            print(' [*] No checkpoint!')
            self.start_epoch = 0

        # Tensorboard Setup
        # current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        current_time = '20201024-102158'
        train_log_dir = 'logs/sketch2pokemon/' + current_time
        self.writer = SummaryWriter(train_log_dir)

        # Stability variables setup
        self.last_test_output = []
        self.cur_test_output = []
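
Example #10 names its criteria after the three CycleGAN terms. A sketch of how they combine into the generator objective (the lambda weights are the paper's defaults, used here for illustration; the Gab = B -> A direction is assumed from sibling repositories):

import torch


def generator_objective(t, a_real, b_real, lamb=10.0, lamb_idt=5.0):
    """Identity + adversarial + cycle-consistency terms for both generators."""
    a_fake, b_fake = t.Gab(b_real), t.Gba(a_real)
    a_recon, b_recon = t.Gab(b_fake), t.Gba(a_fake)

    # Identity: each generator should leave in-domain images unchanged.
    idt = (t.identity_criteron(t.Gab(a_real), a_real) +
           t.identity_criteron(t.Gba(b_real), b_real))

    # Adversarial: fool both discriminators (least-squares targets of 1).
    da_out, db_out = t.Da(a_fake), t.Db(b_fake)
    adv = (t.adversarial_criteron(da_out, torch.ones_like(da_out)) +
           t.adversarial_criteron(db_out, torch.ones_like(db_out)))

    # Cycle consistency: A -> B -> A and B -> A -> B reconstructions.
    cyc = (t.cycle_consistency_criteron(a_recon, a_real) +
           t.cycle_consistency_criteron(b_recon, b_real))

    return adv + lamb * cyc + lamb_idt * idt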
Example #11
    # Optimizers & LR schedulers
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                                   netG_B2A.parameters()),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=ut.LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=ut.LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=ut.LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    # Inputs & targets memory allocation
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    target_real = Variable(Tensor(opt.batchSize, 1).fill_(1.0),
                           requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize, 1).fill_(0.0),
                           requires_grad=False)
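
The preallocated `target_real`/`target_fake` tensors feed a least-squares discriminator update roughly like this (a continuation sketch: `criterion_GAN`, `real_A` and `fake_A` are assumed to be defined elsewhere in the script, and the buffer follows the ReplayBuffer pattern shown earlier):

    optimizer_D_A.zero_grad()

    pred_real = netD_A(real_A)
    loss_D_real = criterion_GAN(pred_real, target_real)  # push real outputs to 1

    fake_A_pooled = fake_A_buffer.push_and_pop(fake_A)   # assumed replay buffer
    pred_fake = netD_A(fake_A_pooled.detach())
    loss_D_fake = criterion_GAN(pred_fake, target_fake)  # push fake outputs to 0

    loss_D_A = 0.5 * (loss_D_real + loss_D_fake)
    loss_D_A.backward()
    optimizer_D_A.step()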
Example #12
    def __init__(self, args):
        # Device
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

        # Generator Networks
        self.G_AB = ResnetGenerator(input_channels=3,
                                    output_channels=3,
                                    ngf=args.ngf,
                                    normalization=args.norm,
                                    use_dropout=not args.no_dropout).to(
                                        self.device)  # A-> B

        self.G_BA = ResnetGenerator(input_channels=3,
                                    output_channels=3,
                                    ngf=args.ngf,
                                    normalization=args.norm,
                                    use_dropout=not args.no_dropout).to(
                                        self.device)  # B-> A

        # Discriminator Networks
        if args.train:
            self.D_A = Discriminator(input_channels=3,
                                     ndf=args.ndf,
                                     normalization=args.norm).to(self.device)
            self.D_B = Discriminator(input_channels=3,
                                     ndf=args.ndf,
                                     normalization=args.norm).to(self.device)

            # Losses
            self.MSE = nn.MSELoss()
            self.L1 = nn.L1Loss()

            # Training items
            self.curr_epoch = 0

            self.gen_optimizer = torch.optim.Adam(
                list(self.G_AB.parameters()) + list(self.G_BA.parameters()),
                lr=args.lr,
                betas=(0.5, 0.999))
            self.dis_optimizer = torch.optim.Adam(list(self.D_A.parameters()) +
                                                  list(self.D_B.parameters()),
                                                  lr=args.lr,
                                                  betas=(0.5, 0.999))

            self.gen_scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.gen_optimizer,
                lr_lambda=utils.LambdaLR(args.epochs, args.decay_epoch).step)
            self.dis_scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.dis_optimizer,
                lr_lambda=utils.LambdaLR(args.epochs, args.decay_epoch).step)

        # Transforms
        # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/2c5f2b14a577753b6ce40716e42dc28b21ed775a/data/base_dataset.py#L81
        # and from default base options
        # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/options/base_options.py
        self.train_transforms = transforms.Compose([
            transforms.Resize(args.load_size, Image.BICUBIC),
            transforms.RandomCrop(args.crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        self.test_transforms = transforms.Compose([
            transforms.Resize(args.crop_size, Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
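
A usage sketch for the transforms above, assuming an `ImageFolder`-style layout with class subdirectories (the path and the `model` instance are hypothetical):

from torch.utils.data import DataLoader
from torchvision import datasets

train_set = datasets.ImageFolder('./data/train', transform=model.train_transforms)
train_loader = DataLoader(train_set, batch_size=1, shuffle=True, num_workers=2)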