def test_all_level_no_mask_yes_attr(args):
    """Test model with input image and attributes."""
    transform = transforms.Compose(
        [Normalize(0.5, 0.5), CenterSquareMask(), ScaleNRotate(), ToTensor()])
    batch_size = 1
    num_attrs = 40
    resolutions_to = [4, 8, 8, 16, 16, 32, 32, 64, 64, 128, 128, 256, 256]  # 512, 512]
    levels = [1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 5.5, 6, 6.5, 7]  # 7.5, 8]
    data_shape = [batch_size, 3, 512, 512]

    G = Generator(data_shape, use_mask=False, num_attrs=num_attrs)
    D = Discriminator(data_shape, num_attrs=num_attrs)

    for res, lev in zip(resolutions_to, levels):
        dataset = CelebAHQDataset(args.data_dir, res, transform)
        dataloader = DataLoader(dataset, batch_size, True)
        sample = next(iter(dataloader))
        image = sample['image']
        masked_image = sample['masked_image']
        mask = sample['mask']
        attr = sample['attr']
        print(f"level: {lev}, resolution: {res}, "
              f"image: {masked_image.shape}, mask: {mask.shape}")

        # Generator
        if isinstance(lev, int):  # training state
            fake_image1 = G(masked_image, attr, cur_level=lev)
            assert list(fake_image1.shape) == [batch_size, 3, res, res], \
                f'{res, lev} test failed'
        else:  # transition state
            fake_image2 = G(masked_image, attr, cur_level=lev)
            assert list(fake_image2.shape) == [batch_size, 3, res, res], \
                f'{res, lev} test failed'

        # Discriminator
        if isinstance(lev, int):  # training state
            cls1, attr1 = D(image, lev)
            assert list(cls1.shape) == [batch_size, 1], \
                f'{res, lev} test failed'
            assert list(attr1.shape) == [batch_size, num_attrs], \
                f'{res, lev} test failed'
        else:  # transition state
            cls2, attr2 = D(image, lev)
            assert list(cls2.shape) == [batch_size, 1], \
                f'{res, lev} test failed'
            assert list(attr2.shape) == [batch_size, num_attrs], \
                f'{res, lev} test failed'
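# Editor's sketch (not part of the test): an integer level denotes a
# stabilized stage and a fractional level a fade-in transition, and the
# target resolution pairs with the level as res = 2 ** (ceil(level) + 1),
# which reproduces the two lists above. The helper name is illustrative.
import math

def level_to_resolution(level):
    """Map a progressive-growing level to its output resolution (assumed)."""
    return 2 ** (math.ceil(level) + 1)

assert [level_to_resolution(lev)
        for lev in [1, 1.5, 2, 2.5, 3]] == [4, 8, 8, 16, 16]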
def test_end_to_end(args):
    """Test end-to-end data handling process."""
    batch_size = 1
    resolutions = [256, 256]
    levels = [7, 7]
    num_classes = 5
    num_layers = 1
    data_shape = [batch_size, 3, 256, 256]
    transform = transforms.Compose([
        Normalize(0.5, 0.5),
        TargetMask(num_classes),
        ScaleNRotate(),
        ToTensor()
    ])

    G = Generator(data_shape)
    D = Discriminator(data_shape, num_classes, num_layers)

    for res, lev in zip(resolutions, levels):
        dataset = VGGFace2Dataset(args.data_dir, res, args.landmark_info_path,
                                  args.identity_info_path, transform)
        dataloader = DataLoader(dataset, batch_size, True)
        sample = next(iter(dataloader))
        image = sample['image']
        real_mask = sample['real_mask']
        obs_mask = sample['obs_mask']
        target_id = sample['target_id']
        print(f"lev: {lev}, res: {res}, image: {image.shape}, "
              f"mask: {real_mask.shape}, {obs_mask.shape}, "
              f"target_id: {target_id}")

        # Generator
        fake_image = G(image, obs_mask, cur_level=lev)
        assert list(fake_image.shape) == [batch_size, 3, res, res], \
            f'Generator: {res, lev} test failed'

        # Discriminator (original)
        cls1, pix_cls1 = D(image, lev)
        assert list(cls1.shape) == [batch_size, 1], \
            f'Discriminator: {res, lev} test failed'
        assert list(pix_cls1.shape) == [batch_size, num_classes, res, res], \
            f'Pixel-Discriminator: {res, lev} test failed'

        cls2, pix_cls2 = D(fake_image, lev)
        assert list(cls2.shape) == [batch_size, 1], \
            f'Discriminator: {res, lev} test failed'
        assert list(pix_cls2.shape) == [batch_size, num_classes, res, res], \
            f'Pixel-Discriminator: {res, lev} test failed'
def run(args):
    # create output path
    Path(hp.output_path).mkdir(parents=True, exist_ok=True)

    # setup nnabla context
    ctx = get_extension_context(args.context, device_id='0')
    nn.set_default_context(ctx)
    hp.comm = CommunicatorWrapper(ctx)
    hp.event = StreamEventHandler(int(hp.comm.ctx.device_id))

    with open(hp.speaker_dir) as f:
        hp.n_speakers = len(f.read().split('\n'))
        logger.info(f'Training data with {hp.n_speakers} speakers.')

    if hp.comm.n_procs > 1 and hp.comm.rank == 0:
        n_procs = hp.comm.n_procs
        logger.info(f'Distributed training with {n_procs} processes.')

    rng = np.random.RandomState(hp.seed)
    train_loader = data_iterator(VCTKDataSource('metadata_train.csv', hp,
                                                shuffle=True, rng=rng),
                                 batch_size=hp.batch_size,
                                 with_memory_cache=False,
                                 rng=rng)
    dataloader = dict(train=train_loader, valid=None)

    gen = NVCNet(hp)
    gen_optim = Optimizer(weight_decay=hp.weight_decay,
                          name='Adam', alpha=hp.g_lr,
                          beta1=hp.beta1, beta2=hp.beta2)
    dis = Discriminator(hp)
    dis_optim = Optimizer(weight_decay=hp.weight_decay,
                          name='Adam', alpha=hp.d_lr,
                          beta1=hp.beta1, beta2=hp.beta2)

    Trainer(gen, gen_optim, dis, dis_optim, dataloader, rng, hp).run()
def save(ckpt_dir_path, global_step, global_init_step, generator: Generator,
         discriminator: Discriminator, gen_optimizer, disc_optimizer):
    if not os.path.exists(ckpt_dir_path):
        os.makedirs(ckpt_dir_path)
    torch.save(generator.state_dict(),
               os.path.join(ckpt_dir_path, "generator.pt"))
    torch.save(discriminator.state_dict(),
               os.path.join(ckpt_dir_path, "discriminator.pt"))
    torch.save(gen_optimizer.state_dict(),
               os.path.join(ckpt_dir_path, "generator_optimizer.pt"))
    torch.save(disc_optimizer.state_dict(),
               os.path.join(ckpt_dir_path, "discriminator_optimizer.pt"))
    with open(os.path.join(ckpt_dir_path, "learning_state.json"), 'w') as f:
        json.dump(
            {
                'global_step': global_step,
                'global_init_step': global_init_step,
            },
            f,
            indent='\t')
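# Editor's sketch: a minimal load() counterpart, assuming the file layout
# written by save() above and the checkpoint keys unpacked in train()
# ('generator', 'discriminator', 'gen_optimizer', 'disc_optimizer',
# 'global_step', 'global_init_step'). It reuses this module's os/json/torch
# imports and is not the project's actual loader.
def load(ckpt_dir_path):
    with open(os.path.join(ckpt_dir_path, "learning_state.json")) as f:
        learning_state = json.load(f)
    return {
        'generator': torch.load(os.path.join(ckpt_dir_path, "generator.pt")),
        'discriminator': torch.load(
            os.path.join(ckpt_dir_path, "discriminator.pt")),
        'gen_optimizer': torch.load(
            os.path.join(ckpt_dir_path, "generator_optimizer.pt")),
        'disc_optimizer': torch.load(
            os.path.join(ckpt_dir_path, "discriminator_optimizer.pt")),
        'global_step': learning_state['global_step'],
        'global_init_step': learning_state['global_init_step'],
    }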
class FaceGen():
    """FaceGen class.

    Attributes:
        D_repeats: how many times the discriminator is trained per G iteration
        total_size: total # of real images in the training
        train_size: # of real images to show before doubling the resolution
        transition_size: # of real images to show when fading in new layers
        mode: running mode {inpainting, generation}
        use_mask: flag for mask use in the model
        dataset_shape: input data shape
        use_cuda: flag for cuda use
        G: generator
        D: discriminator
        optim_G: optimizer for generator
        optim_D: optimizer for discriminator
        loss: losses of generator and discriminator
        replay_memory: replay memory
        global_it: global # of iterations through training
        global_cur_nimg: global # of current images through training
        snapshot: snapshot of intermediate images, checkpoints, tensorboard logs
        real: real images
        obs: observed images
        mask: binary mask
        syn: synthesized images
        cls_real: classes for real images
        cls_syn: classes for synthesized images
        pix_cls_real: pixelwise classes for real images
        pix_cls_syn: pixelwise classes for synthesized images

    """

    def __init__(self, config):
        """Class initializer.

        1. Read configurations from config.py
        2. Check gpu availability
        3. Create a model and training-related objects
            - Model (Generator, Discriminator)
            - Optimizer
            - Loss and loss histories
            - Replay memory
            - Snapshot
        """
        self.config = config
        self.D_repeats = self.config.train.D_repeats
        self.total_size = int(self.config.train.total_size *
                              self.config.train.dataset_unit)
        self.train_size = int(self.config.train.train_size *
                              self.config.train.dataset_unit)
        self.transition_size = int(self.config.train.transition_size *
                                   self.config.train.dataset_unit)
        assert (self.total_size == (self.train_size + self.transition_size)) \
            and self.train_size > 0 and self.transition_size > 0

        # GPU
        self.check_gpu()

        self.mode = self.config.train.mode
        self.use_mask = self.config.train.use_mask

        # Data Shape
        dataset_shape = [
            1, self.config.dataset.num_channels,
            self.config.train.net.max_resolution,
            self.config.train.net.max_resolution
        ]

        # Generator & Discriminator Creation
        self.G = Generator(dataset_shape,
                           fmap_base=self.config.train.net.fmap_base,
                           fmap_min=self.config.train.net.min_resolution,
                           fmap_max=self.config.train.net.max_resolution,
                           latent_size=self.config.train.net.latent_size,
                           use_mask=self.use_mask,
                           leaky_relu=True,
                           instancenorm=True)
        spectralnorm = self.config.loss.gan == Gan.sngan
        self.D = Discriminator(dataset_shape,
                               num_classes=self.config.dataset.num_classes,
                               num_layers=self.config.train.net.num_layers,
                               fmap_base=self.config.train.net.fmap_base,
                               fmap_min=self.config.train.net.min_resolution,
                               fmap_max=self.config.train.net.max_resolution,
                               latent_size=self.config.train.net.latent_size,
                               leaky_relu=True,
                               instancenorm=True,
                               spectralnorm=spectralnorm)

        self.register_on_gpu()
        self.create_optimizer()

        # Loss
        self.loss = FaceGenLoss(self.config, self.use_cuda,
                                self.config.env.num_gpus)

        # Replay Memory
        self.replay_memory = ReplayMemory(self.config, self.use_cuda,
                                          self.config.replay.enabled)

        self.global_it = 1
        self.global_cur_nimg = 1

        # restore
        self.snapshot = Snapshot(self.config, self.use_cuda)
        self.snapshot.prepare_logging()
        self.snapshot.restore_model(self.G, self.D, self.optim_G, self.optim_D)

    def train(self):
        """Training for progressive growing model.

        1. Calculate min/max resolution for the model
        2. For each layer, for each phase
            1) first layer: {training}
            2) remaining layers: {transition, training}
            3) optional: {replaying}
           do one training step
        """
        min_resol = int(np.log2(self.config.train.net.min_resolution))
        max_resol = int(np.log2(self.config.train.net.max_resolution))
        assert 2**max_resol == self.config.train.net.max_resolution \
            and 2**min_resol == self.config.train.net.min_resolution \
            and max_resol >= min_resol >= 2

        from_resol = min_resol
        if self.snapshot.is_restored:
            from_resol = int(np.log2(self.snapshot._resolution))
            self.global_it = self.snapshot._global_it
        assert from_resol <= max_resol

        prev_time = datetime.datetime.now()

        # layer iteration
        for R in range(from_resol, max_resol + 1):
            # Resolution & batch size
            cur_resol = 2**R
            if self.config.train.forced_stop \
                    and cur_resol > self.config.train.forced_stop_resolution:
                break

            batch_size = self.config.sched.batch_dict[cur_resol]
            assert batch_size >= 1
            train_iter = self.train_size // batch_size
            transition_iter = self.transition_size // batch_size
            assert (train_iter != 0) and (transition_iter != 0)

            cur_time = datetime.datetime.now()
            print("Layer Training Time : ", cur_time - prev_time)
            prev_time = cur_time
            print("********** New Layer [%d x %d] : batch_size %d **********"
                  % (cur_resol, cur_resol, batch_size))

            # Phase
            if R == min_resol:
                phases = {Phase.training: [1, train_iter]}
                phase = Phase.training
                total_it = train_iter
            else:
                phases = {
                    Phase.transition: [1, transition_iter],
                    Phase.training:
                        [train_iter + 1, train_iter + transition_iter]
                }
                phase = Phase.transition
                total_it = train_iter + transition_iter

            if self.snapshot.is_restored:
                phase = self.snapshot._phase

            # Iteration
            from_it = phases[phase][0]
            to_it = phases[phase][1]

            if self.snapshot.is_restored:
                from_it = self.snapshot._it + 1
                self.snapshot.is_restored = False

            cur_nimg = from_it * batch_size
            cur_it = from_it

            # load training set
            self.training_set = self.load_train_set(cur_resol, batch_size)
            if len(self.training_set) == 0:
                print("Data loading failed")
                return

            if self.config.replay.enabled:
                self.replay_memory.reset(cur_resol)

            # Learning Rate
            lrate = self.config.optimizer.lrate
            self.G_lrate = lrate.G_dict.get(cur_resol,
                                            self.config.optimizer.lrate.G_base)
            self.D_lrate = lrate.D_dict.get(cur_resol,
                                            self.config.optimizer.lrate.D_base)

            # Training Set
            replay_mode = False
            while cur_it <= total_it:
                for _, sample_batched in enumerate(self.training_set):
                    if sample_batched['image'].shape[0] < batch_size:
                        break
                    if cur_it > total_it:
                        break

                    # transfer from transition to training
                    if cur_it == to_it and cur_it < total_it:
                        phase = Phase.training

                    # calculate current level (from 1)
                    if phase == Phase.transition:
                        # transition [prev level, current level]
                        cur_level = float(R - min_resol +
                                          float(cur_it / to_it))
                    else:
                        # training
                        cur_level = float(R - min_resol + 1)

                    self.real = sample_batched['image']
                    self.real_mask = sample_batched['real_mask']
                    self.obs = sample_batched['image']
                    self.obs_mask = sample_batched['obs_mask']
                    self.source_domain = sample_batched['gender']
                    self.target_domain = sample_batched['fake_gender']

                    cur_nimg = self.train_step(batch_size, cur_it, total_it,
                                               phase, cur_resol, cur_level,
                                               cur_nimg)
                    cur_it += 1
                    self.global_it += 1
                    self.global_cur_nimg += 1

            # Replay Mode
            if self.config.replay.enabled:
                replay_mode = True
                phase = Phase.replaying
                total_it = self.config.replay.replay_count
                for i_batch in range(self.config.replay.replay_count):
                    cur_it = i_batch + 1
                    self.real, self.real_mask, \
                        self.obs, self.obs_mask, self.syn \
                        = self.replay_memory.get_batch(cur_resol, batch_size)
                    if self.real is None:
                        break
                    self.syn = util.tofloat(self.use_cuda, self.syn)
                    cur_nimg = self.train_step(batch_size, cur_it, total_it,
                                               phase, cur_resol, cur_level,
                                               cur_nimg, replay_mode)

    def train_step(self, batch_size, cur_it, total_it, phase, cur_resol,
                   cur_level, cur_nimg, replay_mode=False):
        """Training one step.

        1. Train discriminator for [D_repeats]
        2. Train generator
        3. Snapshot

        Args:
            batch_size: batch size
            cur_it: current # of iterations in the phases of the layer
            total_it: total # of iterations in the phases of the layer
            phase: training, transition, replaying
            cur_resol: image resolution of current layer
            cur_level: progress indicator of progressive growing network
            cur_nimg: current # of images in the phase
            replay_mode: memory replay mode

        Returns:
            cur_nimg: updated # of images in the phase

        """
        self.preprocess()

        # Training discriminator
        d_cnt = 0
        if d_cnt < self.D_repeats:
            self.update_lr(cur_it, total_it, replay_mode)
            self.optim_D.zero_grad()
            self.forward_D(cur_level, detach=True, replay_mode=replay_mode)
            self.backward_D(cur_level)
            if self.config.replay.enabled and replay_mode is False:
                self.replay_memory.append(cur_resol, self.real,
                                          self.real_mask, self.obs,
                                          self.obs_mask, self.syn.detach())
            d_cnt += 1

        # Training generator
        if d_cnt == self.D_repeats:
            self.optim_G.zero_grad()
            self.forward_G(cur_level)
            self.backward_G(cur_level)
            d_cnt = 0

        # model intermediate results
        self.snapshot.snapshot(self.global_it, cur_it, total_it, phase,
                               cur_resol, cur_level, batch_size, self.real,
                               self.syn, self.G, self.D, self.optim_G,
                               self.optim_D, self.loss.g_losses,
                               self.loss.d_losses)
        cur_nimg += batch_size
        return cur_nimg

    def forward_G(self, cur_level):
        """Forward generator.

        Args:
            cur_level: progress indicator of progressive growing network

        """
        self.cls_syn, self.pixel_cls_syn = self.D(self.syn,
                                                  cur_level=cur_level)

    def forward_D(self, cur_level, detach=True, replay_mode=False):
        """Forward discriminator.

        Args:
            cur_level: progress indicator of progressive growing network
            detach: flag whether to detach graph from generator or not
            replay_mode: memory replay mode

        """
        if replay_mode is False:
            self.syn = self.G(self.obs, mask=self.obs_mask,
                              cur_level=cur_level)
            # self.syn = util.normalize_min_max(self.syn)

        self.cls_real, self.pixel_cls_real = self.D(self.real,
                                                    cur_level=cur_level)
        self.cls_syn, self.pixel_cls_syn = self.D(
            self.syn.detach() if detach else self.syn, cur_level=cur_level)

    def backward_G(self, cur_level):
        """Backward generator."""
        self.loss.calc_G_loss(self.G, cur_level, self.real, self.real_mask,
                              self.obs, self.obs_mask, self.syn,
                              self.cls_real, self.cls_syn,
                              self.pixel_cls_real, self.pixel_cls_syn)
        self.loss.g_losses.g_loss.backward()
        self.optim_G.step()

    def backward_D(self, cur_level, retain_graph=True):
        """Backward discriminator.

        Args:
            cur_level: progress indicator of progressive growing network
            retain_graph: flag whether to retain graph of discriminator or not

        """
        self.loss.calc_D_loss(self.D, cur_level, self.real, self.real_mask,
                              self.obs, self.obs_mask, self.source_domain,
                              self.target_domain, self.syn, self.cls_real,
                              self.cls_syn, self.pixel_cls_real,
                              self.pixel_cls_syn)
        self.loss.d_losses.d_loss.backward(retain_graph=retain_graph)
        self.optim_D.step()

    def preprocess(self):
        """Set input type to cuda or cpu according to gpu availability."""
        self.real = util.tofloat(self.use_cuda, self.real)
        self.real_mask = util.tofloat(self.use_cuda, self.real_mask)
        self.obs = util.tofloat(self.use_cuda, self.obs)
        self.obs_mask = util.tofloat(self.use_cuda, self.obs_mask)
        self.source_domain = util.tofloat(self.use_cuda, self.source_domain)
        self.target_domain = util.tofloat(self.use_cuda, self.target_domain)

    def check_gpu(self):
        """Check gpu availability."""
        self.use_cuda = torch.cuda.is_available() \
            and self.config.env.num_gpus > 0
        if self.use_cuda:
            gpus = str(list(range(self.config.env.num_gpus)))
            os.environ['CUDA_VISIBLE_DEVICES'] = gpus

    def register_on_gpu(self):
        """Set model to cuda according to gpu availability."""
        if self.use_cuda:
            self.G.cuda()
            self.D.cuda()

    def load_train_set(self, resol, batch_size):
        """Load train set.

        Args:
            resol: resolution of the current layer
            batch_size: batch size

        """
        num_classes = self.config.dataset.num_classes
        transform_options = transforms.Compose([
            dt.Normalize(0.5, 0.5),
            dt.PolygonMask(num_classes),
            dt.ToTensor()
        ])
        dataset_func = self.config.dataset.func
        ds = self.config.dataset
        datasets = util.call_func_by_name(data_dir=ds.data_dir,
                                          resolution=resol,
                                          landmark_info_path=ds.landmark_path,
                                          identity_info_path=ds.identity_path,
                                          filtered_list=ds.filtering_path,
                                          transform=transform_options,
                                          func=dataset_func)
        # train_dataset & data loader
        return DataLoader(datasets, batch_size, True)

    def create_optimizer(self):
        """Create optimizers of generator and discriminator."""
        self.optim_G = optim.Adam(self.G.parameters(),
                                  lr=self.config.optimizer.lrate.G_base,
                                  betas=(self.config.optimizer.G_opt.beta1,
                                         self.config.optimizer.G_opt.beta2))
        self.optim_D = optim.Adam(self.D.parameters(),
                                  lr=self.config.optimizer.lrate.D_base,
                                  betas=(self.config.optimizer.D_opt.beta1,
                                         self.config.optimizer.D_opt.beta2))

    def rampup(self, cur_it, rampup_it):
        """Ramp up learning rate.

        Args:
            cur_it: current # of iterations in the phase
            rampup_it: # of iterations for ramp up

        """
        if cur_it < rampup_it:
            p = max(0.0, float(cur_it)) / float(rampup_it)
            p = 1.0 - p
            return np.exp(-p * p * 5.0)
        else:
            return 1.0

    def rampdown_linear(self, cur_it, total_it, rampdown_it):
        """Ramp down learning rate.

        Args:
            cur_it: current # of iterations in the phase
            total_it: total # of iterations in the phase
            rampdown_it: # of iterations for ramp down

        """
        if cur_it >= total_it - rampdown_it:
            return float(total_it - cur_it) / rampdown_it
        return 1.0

    def update_lr(self, cur_it, total_it, replay_mode=False):
        """Update learning rate.

        Args:
            cur_it: current # of iterations in the phase
            total_it: total # of iterations in the phase
            replay_mode: memory replay mode

        """
        if replay_mode:
            return

        rampup_it = total_it * self.config.optimizer.lrate.rampup_rate
        rampdown_it = total_it * self.config.optimizer.lrate.rampdown_rate

        # learning rate rampup & rampdown
        for param_group in self.optim_G.param_groups:
            lrate_coef = self.rampup(cur_it, rampup_it)
            lrate_coef *= self.rampdown_linear(cur_it, total_it, rampdown_it)
            param_group['lr'] = lrate_coef * self.G_lrate
            print("learning rate %f" % (param_group['lr']))
        for param_group in self.optim_D.param_groups:
            lrate_coef = self.rampup(cur_it, rampup_it)
            lrate_coef *= self.rampdown_linear(cur_it, self.total_size,
                                               rampdown_it)
            param_group['lr'] = lrate_coef * self.D_lrate
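# Editor's sketch: the rampup/rampdown coefficients above, restated as one
# standalone function to show the schedule shape. rampup() is a Gaussian
# warmup exp(-5 * (1 - t)^2) and rampdown_linear() decays linearly to zero
# over the last rampdown_it iterations; the name lr_coefficient is
# illustrative, not part of the class.
import numpy as np

def lr_coefficient(cur_it, total_it, rampup_it, rampdown_it):
    up = np.exp(-5.0 * (1.0 - cur_it / rampup_it) ** 2) \
        if cur_it < rampup_it else 1.0
    down = (total_it - cur_it) / rampdown_it \
        if cur_it >= total_it - rampdown_it else 1.0
    return up * down

# e.g. with total_it=1000, rampup_it=100, rampdown_it=100 the coefficient is
# about 0.0067 at it=0, 1.0 through the middle, and 0.5 at it=950.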
class MaskCycleGANVCTrainer(object):
    """Trainer for MaskCycleGAN-VC."""

    def __init__(self, args):
        """
        Args:
            args (Namespace): Program arguments from argparser
        """
        # Store args
        self.num_epochs = args.num_epochs
        self.start_epoch = args.start_epoch
        self.generator_lr = args.generator_lr
        self.discriminator_lr = args.discriminator_lr
        self.decay_after = args.decay_after
        self.mini_batch_size = args.batch_size
        self.cycle_loss_lambda = args.cycle_loss_lambda
        self.identity_loss_lambda = args.identity_loss_lambda
        self.device = args.device
        self.epochs_per_save = args.epochs_per_save
        self.sample_rate = args.sample_rate
        self.validation_A_dir = os.path.join(args.origin_data_dir,
                                             args.speaker_A_id)
        self.output_A_dir = os.path.join(args.output_data_dir,
                                         args.speaker_A_id)
        self.validation_B_dir = os.path.join(args.origin_data_dir,
                                             args.speaker_B_id)
        self.output_B_dir = os.path.join(args.output_data_dir,
                                         args.speaker_B_id)
        self.infer_data_dir = args.infer_data_dir
        self.pretrain_models = args.pretrain_models

        # Initialize speaker A's dataset
        self.dataset_A = self.loadPickleFile(
            os.path.join(
                args.preprocessed_data_dir,
                args.speaker_A_id,
                f"{args.speaker_A_id}_normalized.pickle",
            ))
        dataset_A_norm_stats = np.load(
            os.path.join(
                args.preprocessed_data_dir,
                args.speaker_A_id,
                f"{args.speaker_A_id}_norm_stat.npz",
            ))
        self.dataset_A_mean = dataset_A_norm_stats["mean"]
        self.dataset_A_std = dataset_A_norm_stats["std"]

        # Initialize speaker B's dataset
        self.dataset_B = self.loadPickleFile(
            os.path.join(
                args.preprocessed_data_dir,
                args.speaker_B_id,
                f"{args.speaker_B_id}_normalized.pickle",
            ))
        dataset_B_norm_stats = np.load(
            os.path.join(
                args.preprocessed_data_dir,
                args.speaker_B_id,
                f"{args.speaker_B_id}_norm_stat.npz",
            ))
        self.dataset_B_mean = dataset_B_norm_stats["mean"]
        self.dataset_B_std = dataset_B_norm_stats["std"]

        # Compute lr decay rate
        self.n_samples = len(self.dataset_A)
        print(f"n_samples = {self.n_samples}")
        self.generator_lr_decay = self.generator_lr / float(
            self.num_epochs * (self.n_samples // self.mini_batch_size))
        self.discriminator_lr_decay = self.discriminator_lr / float(
            self.num_epochs * (self.n_samples // self.mini_batch_size))
        print(f"generator_lr_decay = {self.generator_lr_decay}")
        print(f"discriminator_lr_decay = {self.discriminator_lr_decay}")

        # Initialize Train Dataloader
        self.num_frames = args.num_frames
        self.dataset = VCDataset(
            datasetA=self.dataset_A,
            datasetB=self.dataset_B,
            n_frames=args.num_frames,
            max_mask_len=args.max_mask_len,
        )
        self.train_dataloader = flow.utils.data.DataLoader(
            dataset=self.dataset,
            batch_size=self.mini_batch_size,
            shuffle=True,
            drop_last=False,
        )

        # Initialize Generators and Discriminators
        self.generator_A2B = Generator().to(self.device)
        self.generator_B2A = Generator().to(self.device)
        self.discriminator_A = Discriminator().to(self.device)
        self.discriminator_B = Discriminator().to(self.device)
        # Discriminators used to compute the two-step adversarial loss
        self.discriminator_A2 = Discriminator().to(self.device)
        self.discriminator_B2 = Discriminator().to(self.device)

        # Initialize Optimizers
        g_params = list(self.generator_A2B.parameters()) + list(
            self.generator_B2A.parameters())
        d_params = (list(self.discriminator_A.parameters()) +
                    list(self.discriminator_B.parameters()) +
                    list(self.discriminator_A2.parameters()) +
                    list(self.discriminator_B2.parameters()))
        self.generator_optimizer = flow.optim.Adam(g_params,
                                                   lr=self.generator_lr,
                                                   betas=(0.5, 0.999))
        self.discriminator_optimizer = flow.optim.Adam(
            d_params, lr=self.discriminator_lr, betas=(0.5, 0.999))

    def adjust_lr_rate(self, optimizer, generator):
        """Decays learning rate.

        Args:
            optimizer (torch.optim): torch optimizer
            generator (bool): Whether to adjust generator lr.
        """
        if generator:
            self.generator_lr = max(
                0.0, self.generator_lr - self.generator_lr_decay)
            for param_group in optimizer.param_groups:
                param_group["lr"] = self.generator_lr
        else:
            self.discriminator_lr = max(
                0.0, self.discriminator_lr - self.discriminator_lr_decay)
            for param_group in optimizer.param_groups:
                param_group["lr"] = self.discriminator_lr

    def reset_grad(self):
        """Sets gradients of the generators and discriminators to zero
        before backpropagation.
        """
        self.generator_optimizer.zero_grad()
        self.discriminator_optimizer.zero_grad()

    def loadPickleFile(self, fileName):
        """Loads a Pickle file.

        Args:
            fileName (str): pickle file path

        Returns:
            file object: The loaded pickle file object
        """
        with open(fileName, "rb") as f:
            return pickle.load(f)

    def train(self):
        """Implements the training loop for MaskCycleGAN-VC."""
        for epoch in range(self.start_epoch, self.num_epochs + 1):
            for i, (real_A, mask_A, real_B,
                    mask_B) in enumerate(self.train_dataloader):
                num_iterations = (self.n_samples //
                                  self.mini_batch_size) * epoch + i
                if num_iterations > 10000:
                    self.identity_loss_lambda = 0
                if num_iterations > self.decay_after:
                    self.adjust_lr_rate(self.generator_optimizer,
                                        generator=True)
                    self.adjust_lr_rate(self.discriminator_optimizer,
                                        generator=False)

                real_A = real_A.to(self.device, dtype=flow.float)
                mask_A = mask_A.to(self.device, dtype=flow.float)
                real_B = real_B.to(self.device, dtype=flow.float)
                mask_B = mask_B.to(self.device, dtype=flow.float)

                # Train Generator
                self.generator_A2B.train()
                self.generator_B2A.train()
                self.discriminator_A.eval()
                self.discriminator_B.eval()
                self.discriminator_A2.eval()
                self.discriminator_B2.eval()

                # Generator Feed Forward
                fake_B = self.generator_A2B(real_A, mask_A)
                cycle_A = self.generator_B2A(fake_B, flow.ones_like(fake_B))
                fake_A = self.generator_B2A(real_B, mask_B)
                cycle_B = self.generator_A2B(fake_A, flow.ones_like(fake_A))
                identity_A = self.generator_B2A(real_A,
                                                flow.ones_like(real_A))
                identity_B = self.generator_A2B(real_B,
                                                flow.ones_like(real_B))
                d_fake_A = self.discriminator_A(fake_A)
                d_fake_B = self.discriminator_B(fake_B)

                # For Two-Step Adversarial Loss
                d_fake_cycle_A = self.discriminator_A2(cycle_A)
                d_fake_cycle_B = self.discriminator_B2(cycle_B)

                # Generator Cycle Loss
                cycleLoss = flow.mean(flow.abs(real_A - cycle_A)) + \
                    flow.mean(flow.abs(real_B - cycle_B))

                # Generator Identity Loss
                identityLoss = flow.mean(flow.abs(real_A - identity_A)) + \
                    flow.mean(flow.abs(real_B - identity_B))

                # Generator Loss
                g_loss_A2B = flow.mean((1 - d_fake_B)**2)
                g_loss_B2A = flow.mean((1 - d_fake_A)**2)

                # Generator Two-Step Adversarial Loss
                generator_loss_A2B_2nd = flow.mean((1 - d_fake_cycle_B)**2)
                generator_loss_B2A_2nd = flow.mean((1 - d_fake_cycle_A)**2)

                # Total Generator Loss
                g_loss = (g_loss_A2B + g_loss_B2A + generator_loss_A2B_2nd +
                          generator_loss_B2A_2nd +
                          self.cycle_loss_lambda * cycleLoss +
                          self.identity_loss_lambda * identityLoss)

                # Backprop for Generator
                self.reset_grad()
                g_loss.backward()
                self.generator_optimizer.step()

                # Train Discriminator
                self.generator_A2B.eval()
                self.generator_B2A.eval()
                self.discriminator_A.train()
                self.discriminator_B.train()
                self.discriminator_A2.train()
                self.discriminator_B2.train()

                # Discriminator Feed Forward
                d_real_A = self.discriminator_A(real_A)
                d_real_B = self.discriminator_B(real_B)
                d_real_A2 = self.discriminator_A2(real_A)
                d_real_B2 = self.discriminator_B2(real_B)
                generated_A = self.generator_B2A(real_B, mask_B)
                d_fake_A = self.discriminator_A(generated_A)

                # For Two-Step Adversarial Loss A->B
                cycled_B = self.generator_A2B(generated_A,
                                              flow.ones_like(generated_A))
                d_cycled_B = self.discriminator_B2(cycled_B)

                generated_B = self.generator_A2B(real_A, mask_A)
                d_fake_B = self.discriminator_B(generated_B)

                # For Two-Step Adversarial Loss B->A
                cycled_A = self.generator_B2A(generated_B,
                                              flow.ones_like(generated_B))
                d_cycled_A = self.discriminator_A2(cycled_A)

                # Loss Functions
                d_loss_A_real = flow.mean((1 - d_real_A)**2)
                d_loss_A_fake = flow.mean((0 - d_fake_A)**2)
                d_loss_A = (d_loss_A_real + d_loss_A_fake) / 2.0

                d_loss_B_real = flow.mean((1 - d_real_B)**2)
                d_loss_B_fake = flow.mean((0 - d_fake_B)**2)
                d_loss_B = (d_loss_B_real + d_loss_B_fake) / 2.0

                # Two-Step Adversarial Loss
                d_loss_A_cycled = flow.mean((0 - d_cycled_A)**2)
                d_loss_B_cycled = flow.mean((0 - d_cycled_B)**2)
                d_loss_A2_real = flow.mean((1 - d_real_A2)**2)
                d_loss_B2_real = flow.mean((1 - d_real_B2)**2)
                d_loss_A_2nd = (d_loss_A2_real + d_loss_A_cycled) / 2.0
                d_loss_B_2nd = (d_loss_B2_real + d_loss_B_cycled) / 2.0

                # Final discriminator loss, including the two-step terms
                d_loss = (d_loss_A + d_loss_B) / 2.0 + \
                    (d_loss_A_2nd + d_loss_B_2nd) / 2.0

                # Backprop for Discriminator
                self.reset_grad()
                d_loss.backward()
                self.discriminator_optimizer.step()

                if (i + 1) % 2 == 0:
                    print(
                        "Iter:{} Generator Loss:{:.4f} Discriminator Loss:{:.4f} "
                        "GA2B:{:.4f} GB2A:{:.4f} G_id:{:.4f} G_cyc:{:.4f} "
                        "D_A:{:.4f} D_B:{:.4f}".format(
                            num_iterations,
                            g_loss.item(),
                            d_loss.item(),
                            g_loss_A2B,
                            g_loss_B2A,
                            identityLoss,
                            cycleLoss,
                            d_loss_A,
                            d_loss_B,
                        ))

            # Save each model checkpoint and run validation
            if epoch % self.epochs_per_save == 0 and epoch != 0:
                self.saveModelCheckPoint(epoch, PATH="model_checkpoint")
                self.validation_for_A_dir()
                self.validation_for_B_dir()

    def infer(self):
        """Implements the inference loop for MaskCycleGAN-VC."""
        # load pretrained models
        self.loadModel(self.pretrain_models)
        num_mcep = 80
        sampling_rate = self.sample_rate
        frame_period = 5.0
        infer_A_dir = self.infer_data_dir

        print("Generating Validation Data B from A...")
        for file in os.listdir(infer_A_dir):
            filePath = os.path.join(infer_A_dir, file)
            wav, _ = librosa.load(filePath, sr=sampling_rate, mono=True)
            wav = preprocess.wav_padding(wav=wav,
                                         sr=sampling_rate,
                                         frame_period=frame_period,
                                         multiple=4)
            f0, timeaxis, sp, ap = preprocess.world_decompose(
                wav=wav, fs=sampling_rate, frame_period=frame_period)
            f0_converted = preprocess.pitch_conversion(
                f0=f0,
                mean_log_src=self.dataset_A_mean,
                std_log_src=self.dataset_A_std,
                mean_log_target=self.dataset_B_mean,
                std_log_target=self.dataset_B_std,
            )
            coded_sp = preprocess.world_encode_spectral_envelop(
                sp=sp, fs=sampling_rate, dim=num_mcep)
            coded_sp_transposed = coded_sp.T
            coded_sp_norm = (coded_sp_transposed -
                             self.dataset_A_mean) / self.dataset_A_std
            coded_sp_norm = np.array([coded_sp_norm])
            if flow.cuda.is_available():
                coded_sp_norm = flow.tensor(coded_sp_norm).cuda().float()
            else:
                coded_sp_norm = flow.tensor(coded_sp_norm).float()
            coded_sp_converted_norm = self.generator_A2B(
                coded_sp_norm, flow.ones_like(coded_sp_norm))
            coded_sp_converted_norm = \
                coded_sp_converted_norm.cpu().detach().numpy()
            coded_sp_converted_norm = np.squeeze(coded_sp_converted_norm)
            coded_sp_converted = (coded_sp_converted_norm *
                                  self.dataset_B_std + self.dataset_B_mean)
            coded_sp_converted = coded_sp_converted.T
            coded_sp_converted = np.ascontiguousarray(
                coded_sp_converted).astype(np.double)
            decoded_sp_converted = preprocess.world_decode_spectral_envelop(
                coded_sp=coded_sp_converted, fs=sampling_rate)
            wav_transformed = preprocess.world_speech_synthesis(
                f0=f0_converted[0],
                decoded_sp=decoded_sp_converted,
                ap=ap,
                fs=sampling_rate,
                frame_period=frame_period,
            )
            sf.write(
                os.path.join(infer_A_dir,
                             "convert_" + os.path.basename(file)),
                wav_transformed,
                sampling_rate,
            )

    def validation_for_A_dir(self):
        num_mcep = 80
        sampling_rate = 22050
        frame_period = 5.0
        validation_A_dir = self.validation_A_dir
        output_A_dir = self.output_A_dir
        os.makedirs(output_A_dir, exist_ok=True)

        print("Generating Validation Data B from A...")
        for file in os.listdir(validation_A_dir):
            filePath = os.path.join(validation_A_dir, file)
            wav, _ = librosa.load(filePath, sr=sampling_rate, mono=True)
            wav = preprocess.wav_padding(wav=wav,
                                         sr=sampling_rate,
                                         frame_period=frame_period,
                                         multiple=4)
            f0, timeaxis, sp, ap = preprocess.world_decompose(
                wav=wav, fs=sampling_rate, frame_period=frame_period)
            f0_converted = preprocess.pitch_conversion(
                f0=f0,
                mean_log_src=self.dataset_A_mean,
                std_log_src=self.dataset_A_std,
                mean_log_target=self.dataset_B_mean,
                std_log_target=self.dataset_B_std,
            )
            coded_sp = preprocess.world_encode_spectral_envelop(
                sp=sp, fs=sampling_rate, dim=num_mcep)
            coded_sp_transposed = coded_sp.T
            coded_sp_norm = (coded_sp_transposed -
                             self.dataset_A_mean) / self.dataset_A_std
            coded_sp_norm = np.array([coded_sp_norm])
            if flow.cuda.is_available():
                coded_sp_norm = flow.tensor(coded_sp_norm).cuda().float()
            else:
                coded_sp_norm = flow.tensor(coded_sp_norm).float()
            coded_sp_converted_norm = self.generator_A2B(
                coded_sp_norm, flow.ones_like(coded_sp_norm))
            coded_sp_converted_norm = \
                coded_sp_converted_norm.cpu().detach().numpy()
            coded_sp_converted_norm = np.squeeze(coded_sp_converted_norm)
            coded_sp_converted = (coded_sp_converted_norm *
                                  self.dataset_B_std + self.dataset_B_mean)
            coded_sp_converted = coded_sp_converted.T
            coded_sp_converted = np.ascontiguousarray(
                coded_sp_converted).astype(np.double)
            decoded_sp_converted = preprocess.world_decode_spectral_envelop(
                coded_sp=coded_sp_converted, fs=sampling_rate)
            wav_transformed = preprocess.world_speech_synthesis(
                f0=f0_converted[0],
                decoded_sp=decoded_sp_converted,
                ap=ap,
                fs=sampling_rate,
                frame_period=frame_period,
            )
            sf.write(
                os.path.join(output_A_dir,
                             "convert_" + os.path.basename(file)),
                wav_transformed,
                sampling_rate,
            )

    def validation_for_B_dir(self):
        num_mcep = 80
        sampling_rate = 22050
        frame_period = 5.0
        validation_B_dir = self.validation_B_dir
        output_B_dir = self.output_B_dir
        os.makedirs(output_B_dir, exist_ok=True)

        print("Generating Validation Data A from B...")
        for file in os.listdir(validation_B_dir):
            filePath = os.path.join(validation_B_dir, file)
            wav, _ = librosa.load(filePath, sr=sampling_rate, mono=True)
            wav = preprocess.wav_padding(wav=wav,
                                         sr=sampling_rate,
                                         frame_period=frame_period,
                                         multiple=4)
            f0, timeaxis, sp, ap = preprocess.world_decompose(
                wav=wav, fs=sampling_rate, frame_period=frame_period)
            f0_converted = preprocess.pitch_conversion(
                f0=f0,
                mean_log_src=self.dataset_B_mean,
                std_log_src=self.dataset_B_std,
                mean_log_target=self.dataset_A_mean,
                std_log_target=self.dataset_A_std,
            )
            coded_sp = preprocess.world_encode_spectral_envelop(
                sp=sp, fs=sampling_rate, dim=num_mcep)
            coded_sp_transposed = coded_sp.T
            coded_sp_norm = (coded_sp_transposed -
                             self.dataset_B_mean) / self.dataset_B_std
            coded_sp_norm = np.array([coded_sp_norm])
            if flow.cuda.is_available():
                coded_sp_norm = flow.tensor(coded_sp_norm).cuda().float()
            else:
                coded_sp_norm = flow.tensor(coded_sp_norm).float()
            coded_sp_converted_norm = self.generator_B2A(
                coded_sp_norm, flow.ones_like(coded_sp_norm))
            coded_sp_converted_norm = \
                coded_sp_converted_norm.cpu().detach().numpy()
            coded_sp_converted_norm = np.squeeze(coded_sp_converted_norm)
            coded_sp_converted = (coded_sp_converted_norm *
                                  self.dataset_A_std + self.dataset_A_mean)
            coded_sp_converted = coded_sp_converted.T
            coded_sp_converted = np.ascontiguousarray(
                coded_sp_converted).astype(np.double)
            decoded_sp_converted = preprocess.world_decode_spectral_envelop(
                coded_sp=coded_sp_converted, fs=sampling_rate)
            wav_transformed = preprocess.world_speech_synthesis(
                f0=f0_converted[0],
                decoded_sp=decoded_sp_converted,
                ap=ap,
                fs=sampling_rate,
                frame_period=frame_period,
            )
            sf.write(
                os.path.join(output_B_dir,
                             "convert_" + os.path.basename(file)),
                wav_transformed,
                sampling_rate,
            )

    def saveModelCheckPoint(self, epoch, PATH):
        flow.save(
            self.generator_A2B.state_dict(),
            os.path.join(PATH, "generator_A2B_%d" % epoch),
        )
        flow.save(
            self.generator_B2A.state_dict(),
            os.path.join(PATH, "generator_B2A_%d" % epoch),
        )
        flow.save(
            self.discriminator_A.state_dict(),
            os.path.join(PATH, "discriminator_A_%d" % epoch),
        )
        flow.save(
            self.discriminator_B.state_dict(),
            os.path.join(PATH, "discriminator_B_%d" % epoch),
        )

    def loadModel(self, PATH):
        self.generator_A2B.load_state_dict(
            flow.load(os.path.join(PATH, "generator_A2B")))
        self.generator_B2A.load_state_dict(
            flow.load(os.path.join(PATH, "generator_B2A")))
        self.discriminator_A.load_state_dict(
            flow.load(os.path.join(PATH, "discriminator_A")))
        self.discriminator_B.load_state_dict(
            flow.load(os.path.join(PATH, "discriminator_B")))
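# Editor's sketch: the adversarial terms in train() above follow the LSGAN
# objective (real target 1, fake target 0). Restated as standalone helpers
# to make the pattern explicit; the function names are illustrative and not
# part of the trainer.
import oneflow as flow

def lsgan_d_loss(d_real, d_fake):
    # Least-squares discriminator loss: push real scores to 1, fakes to 0.
    return (flow.mean((1 - d_real) ** 2) + flow.mean(d_fake ** 2)) / 2.0

def lsgan_g_loss(d_fake):
    # Least-squares generator loss: push fake scores toward 1.
    return flow.mean((1 - d_fake) ** 2)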
transform = transforms.Compose([Normalize(0.5, 0.5), ToTensor()])
batch_size = 2
num_classes = 2
num_attrs = 1
resolutions_to = [4, 8, 8, 16, 16, 32, 32]
levels = [1, 1.125, 2, 2.5, 3, 3.5, 4]
data_shape = [batch_size, 3, 32, 32]

G = Generator(data_shape,
              use_mask=False,
              use_attrs=True,
              num_attrs=num_attrs,
              latent_size=256)
D = Discriminator(data_shape,
                  use_attrs=True,
                  num_attrs=num_attrs,
                  latent_size=256)

for res, lev in zip(resolutions_to, levels):
    dataset = VGGFace2Dataset('./dataset/VGGFACE2/train', res,
                              landmark_info_path, identity_info_path,
                              filtered_list, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    sample = next(iter(dataloader))
    print(f"resolution: {res}, image: {sample['image'].shape}, "
          f"attr: {sample['attr'].shape}")
# Setup Metrics
cty_running_metrics = runningScore(num_classes)
model_dict = {}

# Setup Model
print('building models ...')
enc_shared = SharedEncoder().cuda()
dclf1 = DomainClassifier().cuda()
dclf2 = DomainClassifier().cuda()
enc_s = PrivateEncoder(64, private_code_size).cuda()
enc_t = PrivateEncoder(64, private_code_size).cuda()
dec_s = PrivateDecoder(shared_code_channels, private_code_size).cuda()
dec_t = dec_s
dis_s2t = Discriminator().cuda()
dis_t2s = Discriminator().cuda()

model_dict['enc_shared'] = enc_shared
model_dict['dclf1'] = dclf1
model_dict['dclf2'] = dclf2
model_dict['enc_s'] = enc_s
model_dict['enc_t'] = enc_t
model_dict['dec_s'] = dec_s
model_dict['dec_t'] = dec_t
model_dict['dis_s2t'] = dis_s2t
model_dict['dis_t2s'] = dis_t2s

enc_shared_opt = optim.SGD(enc_shared.optim_parameters(learning_rate_seg),
                           lr=learning_rate_seg,
                           momentum=0.9,
                                          num_workers=opt.workers)
print("Train set size: " + str(len(trainset)))

# Whether to resume training from a checkpoint
if opt.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), \
        'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.t7')
    netD = checkpoint['netD']
    netG = checkpoint['netG']
    start_epoch = checkpoint['epoch']
else:
    print('==> Building model..')
    # Create an instance of the nn.Module class defined above:
    netD = Discriminator(trainset, opt.batchSize, reuse=isTrain)
    netG = Generator(randomInput, opt.batchSize, reuse=True)
    start_epoch = 0

# For training on GPU, we need to transfer net and data onto the GPU
# http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#training-on-gpu
if opt.ngpu >= 1:
    netD = netD.cuda()
    netD = torch.nn.DataParallel(netD,
                                 device_ids=range(torch.cuda.device_count()))
    netG = netG.cuda()
    netG = torch.nn.DataParallel(netG,
                                 device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

# Optimizers
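# Editor's sketch: a save counterpart matching the checkpoint keys loaded
# above ('netD', 'netG', 'epoch'), mirroring the './checkpoint/ckpt.t7'
# path. This helper is assumed, not shown in the original script, and
# reuses the script's os/torch imports.
def save_checkpoint(netD, netG, epoch, path='./checkpoint/ckpt.t7'):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save({'netD': netD, 'netG': netG, 'epoch': epoch}, path)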
def train(args, generator: Generator, discriminator: Discriminator,
          feature_extractor: FeatureExtractor, photo_dataloader,
          edge_smooth_dataloader, animation_dataloader, checkpoint_dir=None):
    tb_writer = SummaryWriter()

    gen_criterion = nn.BCELoss().to(args.device)
    disc_criterion = nn.BCELoss().to(args.device)
    content_criterion = nn.L1Loss().to(args.device)

    gen_optimizer = torch.optim.Adam(generator.parameters(),
                                     lr=args.lr,
                                     betas=(args.adam_beta, 0.999))
    disc_optimizer = torch.optim.Adam(discriminator.parameters(),
                                      lr=args.lr,
                                      betas=(args.adam_beta, 0.999))

    global_step = 0
    global_init_step = 0

    # The number of steps to skip when loading a checkpoint
    skipped_step = 0
    skipped_init_step = 0

    cur_epoch = 0
    cur_init_epoch = 0

    data_len = min(len(photo_dataloader), len(edge_smooth_dataloader),
                   len(animation_dataloader))

    if checkpoint_dir:
        try:
            checkpoint_dict = load(checkpoint_dir)
            generator.load_state_dict(checkpoint_dict['generator'])
            discriminator.load_state_dict(checkpoint_dict['discriminator'])
            gen_optimizer.load_state_dict(checkpoint_dict['gen_optimizer'])
            disc_optimizer.load_state_dict(checkpoint_dict['disc_optimizer'])

            global_step = checkpoint_dict['global_step']
            global_init_step = checkpoint_dict['global_init_step']

            cur_epoch = global_step // data_len
            cur_init_epoch = global_init_step // len(photo_dataloader)

            skipped_step = global_step % data_len
            skipped_init_step = global_init_step % len(photo_dataloader)
            logger.info("Start training with,")
            logger.info("In initialization step, epoch: %d, step: %d",
                        cur_init_epoch, skipped_init_step)
            logger.info("In main train step, epoch: %d, step: %d", cur_epoch,
                        skipped_step)
        except Exception:
            logger.info("Wrong checkpoint path")

    t_total = data_len * args.n_epochs
    t_init_total = len(photo_dataloader) * args.n_init_epoch

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num photo examples = %d", len(photo_dataloader))
    logger.info("  Num edge_smooth examples = %d",
                len(edge_smooth_dataloader))
    logger.info("  Num animation examples = %d", len(animation_dataloader))
    logger.info("  Num Epochs = %d", args.n_epochs)
    logger.info("  Total train batch size = %d", args.batch_size)
    logger.info("  Total optimization steps = %d", t_total)
    logger.info("  Num Init Epochs = %d", args.n_init_epoch)
    logger.info("  Total Init optimization steps = %d", t_init_total)
    logger.info("  Logging steps = %d", args.logging_steps)
    logger.info("  Save steps = %d", args.save_steps)

    init_phase = True
    try:
        generator.train()
        discriminator.train()

        global_init_loss = 0
        # --- Initialization phase: content loss only
        mb = master_bar(range(cur_init_epoch, args.n_init_epoch))
        for init_epoch in mb:
            epoch_iter = progress_bar(photo_dataloader, parent=mb)
            for step, (photo, _) in enumerate(epoch_iter):
                if skipped_init_step > 0:
                    skipped_init_step -= 1
                    continue
                photo = photo.to(args.device)
                gen_optimizer.zero_grad()

                x_features = feature_extractor((photo + 1) / 2).detach()
                Gx = generator(photo)
                Gx_features = feature_extractor((Gx + 1) / 2)

                content_loss = args.content_loss_weight * content_criterion(
                    Gx_features, x_features)
                content_loss.backward()
                gen_optimizer.step()

                global_init_loss += content_loss.item()
                global_init_step += 1

                if args.save_steps > 0 and \
                        global_init_step % args.save_steps == 0:
                    logger.info(
                        "Save Initialization Phase, init_epoch: %d, init_step: %d",
                        init_epoch, global_init_step)
                    save(checkpoint_dir, global_step, global_init_step,
                         generator, discriminator, gen_optimizer,
                         disc_optimizer)

                if args.logging_steps > 0 and \
                        global_init_step % args.logging_steps == 0:
                    tb_writer.add_scalar('Initialization Phase/Content Loss',
                                         content_loss.item(),
                                         global_init_step)
                    tb_writer.add_scalar(
                        'Initialization Phase/Global Generator Loss',
                        global_init_loss / global_init_step,
                        global_init_step)
                    logger.info(
                        "Initialization Phase, Epoch: %d, Global Step: %d, Content Loss: %.4f",
                        init_epoch, global_init_step,
                        global_init_loss / global_init_step)
        # -----------------------------------------------------
        logger.info("Finish Initialization Phase, save model...")
        save(checkpoint_dir, global_step, global_init_step, generator,
             discriminator, gen_optimizer, disc_optimizer)

        init_phase = False
        global_loss_D = 0
        global_loss_G = 0
        global_loss_content = 0
        mb = master_bar(range(cur_epoch, args.n_epochs))
        for epoch in mb:
            epoch_iter = progress_bar(list(
                zip(animation_dataloader, edge_smooth_dataloader,
                    photo_dataloader)),
                                      parent=mb)
            for step, ((animation, _), (edge_smoothed, _),
                       (photo, _)) in enumerate(epoch_iter):
                if skipped_step > 0:
                    skipped_step -= 1
                    continue
                animation = animation.to(args.device)
                edge_smoothed = edge_smoothed.to(args.device)
                photo = photo.to(args.device)

                disc_optimizer.zero_grad()
                # --- Train discriminator
                # ------ with animation images (real target)
                animation_disc = discriminator(animation)
                animation_target = torch.ones_like(animation_disc)
                loss_animation_disc = disc_criterion(animation_disc,
                                                     animation_target)

                # ------ with edge-smoothed images (fake target)
                edge_smoothed_disc = discriminator(edge_smoothed)
                edge_smoothed_target = torch.zeros_like(edge_smoothed_disc)
                loss_edge_disc = disc_criterion(edge_smoothed_disc,
                                                edge_smoothed_target)

                # ------ with generated images (fake target)
                generated_image = generator(photo).detach()
                generated_image_disc = discriminator(generated_image)
                generated_image_target = torch.zeros_like(
                    generated_image_disc)
                loss_generated_disc = disc_criterion(generated_image_disc,
                                                     generated_image_target)

                loss_disc = loss_animation_disc + loss_edge_disc + \
                    loss_generated_disc
                loss_disc.backward()
                disc_optimizer.step()

                global_loss_D += loss_disc.item()

                # --- Train generator with adversarial loss
                gen_optimizer.zero_grad()
                generated_image = generator(photo)
                generated_image_disc = discriminator(generated_image)
                generated_image_target = torch.ones_like(
                    generated_image_disc)
                loss_adv = gen_criterion(generated_image_disc,
                                         generated_image_target)

                # ------ and with content loss
                x_features = feature_extractor((photo + 1) / 2).detach()
                Gx_features = feature_extractor((generated_image + 1) / 2)
                loss_content = args.content_loss_weight * content_criterion(
                    Gx_features, x_features)

                loss_gen = loss_adv + loss_content
                loss_gen.backward()
                gen_optimizer.step()

                global_loss_G += loss_adv.item()
                global_loss_content += loss_content.item()
                global_step += 1

                if args.save_steps > 0 and \
                        global_step % args.save_steps == 0:
                    logger.info("Save Training Phase, epoch: %d, step: %d",
                                epoch, global_step)
                    save(checkpoint_dir, global_step, global_init_step,
                         generator, discriminator, gen_optimizer,
                         disc_optimizer)

                if args.logging_steps > 0 and \
                        global_step % args.logging_steps == 0:
                    tb_writer.add_scalar('Train Phase/Generator Loss',
                                         loss_adv.item(), global_step)
                    tb_writer.add_scalar('Train Phase/Discriminator Loss',
                                         loss_disc.item(), global_step)
                    tb_writer.add_scalar('Train Phase/Content Loss',
                                         loss_content.item(), global_step)
                    tb_writer.add_scalar('Train Phase/Global Generator Loss',
                                         global_loss_G / global_step,
                                         global_step)
                    tb_writer.add_scalar(
                        'Train Phase/Global Discriminator Loss',
                        global_loss_D / global_step, global_step)
                    tb_writer.add_scalar('Train Phase/Global Content Loss',
                                         global_loss_content / global_step,
                                         global_step)
                    logger.info(
                        "Training Phase, Epoch: %d, Global Step: %d, Disc Loss %.4f, Gen Loss %.4f, Content Loss: %.4f",
                        epoch, global_step, global_loss_D / global_step,
                        global_loss_G / global_step,
                        global_loss_content / global_step)

    except KeyboardInterrupt:
        if init_phase:
            logger.info("KeyboardInterrupt in Initialization Phase!")
            logger.info("Save models, init_epoch: %d, init_step: %d",
                        init_epoch, global_init_step)
        else:
            logger.info("KeyboardInterrupt in Training Phase!")
            logger.info("Save models, epoch: %d, step: %d", epoch,
                        global_step)
        save(checkpoint_dir, global_step, global_init_step, generator,
             discriminator, gen_optimizer, disc_optimizer)
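# Editor's sketch: a minimal invocation of train() with the argument fields
# it actually reads (device, lr, adam_beta, n_epochs, n_init_epoch,
# content_loss_weight, batch_size, logging_steps, save_steps). The concrete
# values are placeholders, not the project's defaults, and the models and
# dataloaders are assumed to be constructed elsewhere.
from argparse import Namespace

args = Namespace(device='cuda', lr=2e-4, adam_beta=0.5, n_epochs=100,
                 n_init_epoch=10, content_loss_weight=10.0, batch_size=8,
                 logging_steps=100, save_steps=1000)
train(args, generator, discriminator, feature_extractor, photo_dataloader,
      edge_smooth_dataloader, animation_dataloader,
      checkpoint_dir='checkpoints')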
class Solver(object): def __init__(self, data_loader, config): self.config = config self.data_loader = data_loader # Model configurations. self.lambda_cycle = config.lambda_cycle self.lambda_cls = config.lambda_cls self.lambda_identity = config.lambda_identity # Training configurations. self.data_dir = config.data_dir self.test_dir = config.test_dir self.batch_size = config.batch_size self.num_iters = config.num_iters self.num_iters_decay = config.num_iters_decay self.g_lr = config.g_lr self.d_lr = config.d_lr self.c_lr = config.c_lr self.n_critic = config.n_critic self.beta1 = config.beta1 self.beta2 = config.beta2 self.resume_iters = config.resume_iters # Test configurations. self.pretrain_models = config.pretrain_models self.sample_dir = config.sample_dir self.trg_speaker = ast.literal_eval(config.trg_speaker) self.src_speaker = config.src_speaker # Miscellaneous. self.device = flow.device( "cuda:0" if flow.cuda.is_available() else "cpu") self.spk_enc = LabelBinarizer().fit(speakers) # Directories. self.model_save_dir = config.model_save_dir self.result_dir = config.result_dir self.use_gradient_penalty = config.use_gradient_penalty # Step size. self.log_step = config.log_step self.sample_step = config.sample_step self.model_save_step = config.model_save_step self.lr_update_step = config.lr_update_step # Build the model. self.build_model() def build_model(self): self.G = Generator() self.D = Discriminator() self.C = DomainClassifier() self.g_optimizer = flow.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2]) self.d_optimizer = flow.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2]) self.c_optimizer = flow.optim.Adam(self.C.parameters(), self.c_lr, [self.beta1, self.beta2]) self.print_network(self.G, "G") self.print_network(self.D, "D") self.print_network(self.C, "C") self.G.to(self.device) self.D.to(self.device) self.C.to(self.device) def print_network(self, model, name): """Print out the network information.""" num_params = 0 for p in model.parameters(): num_params += p.numel() print(model) print(name) print("The number of parameters: {}".format(num_params)) def update_lr(self, g_lr, d_lr, c_lr): """Decay learning rates of the generator and discriminator and classifier.""" for param_group in self.g_optimizer.param_groups: param_group["lr"] = g_lr for param_group in self.d_optimizer.param_groups: param_group["lr"] = d_lr for param_group in self.c_optimizer.param_groups: param_group["lr"] = c_lr def train(self): # Learning rate cache for decaying. g_lr = self.g_lr d_lr = self.d_lr c_lr = self.c_lr start_iters = 0 if self.resume_iters: pass norm = Normalizer() data_iter = iter(self.data_loader) print("Start training......") start_time = datetime.now() for i in range(start_iters, self.num_iters): # Preprocess input data # Fetch real images and labels. try: x_real, speaker_idx_org, label_org = next(data_iter) except: data_iter = iter(self.data_loader) x_real, speaker_idx_org, label_org = next(data_iter) # Generate target domain labels randomly. rand_idx = flow.randperm(label_org.size(0)) label_trg = label_org[rand_idx] speaker_idx_trg = speaker_idx_org[rand_idx] x_real = x_real.to(self.device) # Original domain one-hot labels. label_org = label_org.to(self.device) # Target domain one-hot labels. label_trg = label_trg.to(self.device) speaker_idx_org = speaker_idx_org.to(self.device) speaker_idx_trg = speaker_idx_trg.to(self.device) # Train the discriminator # Compute loss with real audio frame. 
CELoss = nn.CrossEntropyLoss() cls_real = self.C(x_real) cls_loss_real = CELoss(input=cls_real, target=speaker_idx_org) self.reset_grad() cls_loss_real.backward() self.c_optimizer.step() # Logging. loss = {} loss["C/C_loss"] = cls_loss_real.item() out_r = self.D(x_real, label_org) # Compute loss with fake audio frame. x_fake = self.G(x_real, label_trg) out_f = self.D(x_fake.detach(), label_trg) d_loss_t = nn.BCEWithLogitsLoss()( input=out_f, target=flow.zeros_like( out_f).float()) + nn.BCEWithLogitsLoss()( input=out_r, target=flow.ones_like(out_r).float()) out_cls = self.C(x_fake) d_loss_cls = CELoss(input=out_cls, target=speaker_idx_trg) # Compute loss for gradient penalty. alpha = flow.rand(x_real.size(0), 1, 1, 1).to(self.device) x_hat = ((alpha * x_real + (1 - alpha) * x_fake).detach().requires_grad_(True)) out_src = self.D(x_hat, label_trg) # TODO: Second-order derivation is not currently supported in oneflow, so gradient penalty cannot be used temporarily. if self.use_gradient_penalty: d_loss_gp = self.gradient_penalty(out_src, x_hat) d_loss = d_loss_t + self.lambda_cls * d_loss_cls + 5 * d_loss_gp else: d_loss = d_loss_t + self.lambda_cls * d_loss_cls self.reset_grad() d_loss.backward() self.d_optimizer.step() loss["D/D_loss"] = d_loss.item() # Train the generator if (i + 1) % self.n_critic == 0: # Original-to-target domain. x_fake = self.G(x_real, label_trg) g_out_src = self.D(x_fake, label_trg) g_loss_fake = nn.BCEWithLogitsLoss()( input=g_out_src, target=flow.ones_like(g_out_src).float()) out_cls = self.C(x_real) g_loss_cls = CELoss(input=out_cls, target=speaker_idx_org) # Target-to-original domain. x_reconst = self.G(x_fake, label_org) g_loss_rec = nn.L1Loss()(x_reconst, x_real) # Original-to-Original domain(identity). x_fake_iden = self.G(x_real, label_org) id_loss = nn.L1Loss()(x_fake_iden, x_real) # Backward and optimize. g_loss = (g_loss_fake + self.lambda_cycle * g_loss_rec + self.lambda_cls * g_loss_cls + self.lambda_identity * id_loss) self.reset_grad() g_loss.backward() self.g_optimizer.step() # Logging. loss["G/loss_fake"] = g_loss_fake.item() loss["G/loss_rec"] = g_loss_rec.item() loss["G/loss_cls"] = g_loss_cls.item() loss["G/loss_id"] = id_loss.item() loss["G/g_loss"] = g_loss.item() # Miscellaneous # Print out training information. if (i + 1) % self.log_step == 0: et = datetime.now() - start_time et = str(et)[:-7] log = "Elapsed [{}], Iteration [{}/{}]".format( et, i + 1, self.num_iters) for tag, value in loss.items(): log += ", {}: {:.4f}".format(tag, value) print(log) # Translate fixed images for debugging. 
if (i + 1) % self.sample_step == 0: with flow.no_grad(): d, speaker = TestSet(self.test_dir).test_data() target = random.choice( [x for x in speakers if x != speaker]) label_t = self.spk_enc.transform([target])[0] label_t = np.asarray([label_t]) for filename, content in d.items(): f0 = content["f0"] ap = content["ap"] sp_norm_pad = self.pad_coded_sp( content["coded_sp_norm"]) convert_result = [] for start_idx in range( 0, sp_norm_pad.shape[1] - FRAMES + 1, FRAMES): one_seg = sp_norm_pad[:, start_idx:start_idx + FRAMES] one_seg = flow.Tensor(one_seg).to(self.device) one_seg = one_seg.view(1, 1, one_seg.size(0), one_seg.size(1)) l = flow.Tensor(label_t) one_seg = one_seg.to(self.device) l = l.to(self.device) one_set_return = self.G(one_seg, l).detach().cpu().numpy() one_set_return = np.squeeze(one_set_return) one_set_return = norm.backward_process( one_set_return, target) convert_result.append(one_set_return) convert_con = np.concatenate(convert_result, axis=1) convert_con = convert_con[:, 0:content["coded_sp_norm"]. shape[1]] contigu = np.ascontiguousarray(convert_con.T, dtype=np.float64) decoded_sp = decode_spectral_envelope(contigu, SAMPLE_RATE, fft_size=FFTSIZE) f0_converted = norm.pitch_conversion( f0, speaker, target) wav = synthesize(f0_converted, decoded_sp, ap, SAMPLE_RATE) name = f"{speaker}-{target}_iter{i+1}_{filename}" path = os.path.join(self.sample_dir, name) print(f"[save]:{path}") sf.write(path, wav, SAMPLE_RATE) # Save model checkpoints. if (i + 1) % self.model_save_step == 0: G_path = os.path.join(self.model_save_dir, "{}-G".format(i + 1)) D_path = os.path.join(self.model_save_dir, "{}-D".format(i + 1)) C_path = os.path.join(self.model_save_dir, "{}-C".format(i + 1)) flow.save(self.G.state_dict(), G_path) flow.save(self.D.state_dict(), D_path) flow.save(self.C.state_dict(), C_path) print("Saved model checkpoints into {}...".format( self.model_save_dir)) # Decay learning rates. 
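# A note on the decay arithmetic below: every `lr_update_step` iterations
# during the final `num_iters_decay` iterations, each rate drops by a fixed
# `initial_lr / num_iters_decay`. With hypothetical values g_lr=1e-4,
# num_iters_decay=100000, lr_update_step=1000, there are 100 decay events
# of 1e-9 each, leaving a final rate of ~9.99e-5 rather than ~0; the
# schedule only anneals to zero when lr_update_step == 1.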
if (i + 1) % self.lr_update_step == 0 and (i + 1) > ( self.num_iters - self.num_iters_decay): g_lr -= self.g_lr / float(self.num_iters_decay) d_lr -= self.d_lr / float(self.num_iters_decay) c_lr -= self.c_lr / float(self.num_iters_decay) self.update_lr(g_lr, d_lr, c_lr) print("Decayed learning rates, g_lr: {}, d_lr: {}, c_lr: {}.".format( g_lr, d_lr, c_lr)) def gradient_penalty(self, y, x): """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.""" weight = flow.ones(y.size()).to(self.device) dydx = flow.autograd.grad(outputs=y, inputs=x, out_grads=weight, retain_graph=True, create_graph=True)[0] dydx = dydx.view(dydx.size(0), -1) dydx_l2norm = flow.sqrt(flow.sum(dydx**2, dim=1)) return flow.mean((dydx_l2norm - 1)**2) def reset_grad(self): """Reset the gradient buffers.""" self.g_optimizer.zero_grad() self.d_optimizer.zero_grad() self.c_optimizer.zero_grad() def restore_model(self, model_save_dir): """Restore the trained generator, discriminator, and domain classifier.""" print("Loading the pretrained models...") G_path = os.path.join(model_save_dir, "200000-G") D_path = os.path.join(model_save_dir, "200000-D") C_path = os.path.join(model_save_dir, "200000-C") self.G.load_state_dict(flow.load(G_path)) self.D.load_state_dict(flow.load(D_path)) self.C.load_state_dict(flow.load(C_path)) @staticmethod def pad_coded_sp(coded_sp_norm): f_len = coded_sp_norm.shape[1] if f_len >= FRAMES: pad_length = FRAMES - (f_len - (f_len // FRAMES) * FRAMES) else: pad_length = FRAMES - f_len sp_norm_pad = np.hstack( (coded_sp_norm, np.zeros((coded_sp_norm.shape[0], pad_length)))) return sp_norm_pad def test(self): """Translate speech using StarGAN.""" # Load the trained generator. self.restore_model(self.pretrain_models) norm = Normalizer() # Set data loader. d, speaker = TestSet(self.test_dir).test_data(self.src_speaker) targets = self.trg_speaker for target in targets: print(target) assert target in speakers label_t = self.spk_enc.transform([target])[0] label_t = np.asarray([label_t]) with flow.no_grad(): for filename, content in d.items(): f0 = content["f0"] ap = content["ap"] sp_norm_pad = self.pad_coded_sp(content["coded_sp_norm"]) convert_result = [] for start_idx in range(0, sp_norm_pad.shape[1] - FRAMES + 1, FRAMES): one_seg = sp_norm_pad[:, start_idx:start_idx + FRAMES] one_seg = flow.Tensor(one_seg).to(self.device) one_seg = one_seg.view(1, 1, one_seg.size(0), one_seg.size(1)) l = flow.Tensor(label_t) one_seg = one_seg.to(self.device) l = l.to(self.device) one_set_return = self.G(one_seg, l).detach().cpu().numpy() one_set_return = np.squeeze(one_set_return) one_set_return = norm.backward_process( one_set_return, target) convert_result.append(one_set_return) convert_con = np.concatenate(convert_result, axis=1) convert_con = convert_con[:, 0:content["coded_sp_norm"]. shape[1]] contigu = np.ascontiguousarray(convert_con.T, dtype=np.float64) decoded_sp = decode_spectral_envelope(contigu, SAMPLE_RATE, fft_size=FFTSIZE) f0_converted = norm.pitch_conversion(f0, speaker, target) wav = synthesize(f0_converted, decoded_sp, ap, SAMPLE_RATE) name = f"{speaker}-{target}_{filename}" path = os.path.join(self.result_dir, name) print(f"[save]:{path}") sf.write(path, wav, SAMPLE_RATE)
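The pad_coded_sp helper above rounds the frame axis up to a multiple of FRAMES, padding one full extra block when the length is already an exact multiple (the surplus is trimmed again after conversion). A self-contained sketch of that arithmetic, with FRAMES assumed to be 128 to match the segmenting loops:

import numpy as np

FRAMES = 128  # assumed segment width, matching the conversion loops above

def pad_to_frames(coded_sp_norm, frames=FRAMES):
    # Same arithmetic as Solver.pad_coded_sp: zero-pad on the right so the
    # frame axis becomes a (non-zero) multiple of `frames`.
    f_len = coded_sp_norm.shape[1]
    if f_len >= frames:
        pad_length = frames - (f_len - (f_len // frames) * frames)
    else:
        pad_length = frames - f_len
    return np.hstack(
        (coded_sp_norm, np.zeros((coded_sp_norm.shape[0], pad_length))))

x = np.ones((36, 300))                        # 36 MCEP dims, 300 frames
assert pad_to_frames(x).shape == (36, 384)    # 300 -> 3 * 128
y = np.ones((36, 256))                        # already a multiple of 128
assert pad_to_frames(y).shape == (36, 384)    # pads one full extra block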
NUM_EPOCHS = opt.num_epochs decay_interval = NUM_EPOCHS // 8 decay_indices = [decay_interval, decay_interval*2, decay_interval*4, decay_interval*6] discriminator_cycle = opt.discriminator_cycle generator_cycle = opt.generator_cycle train_set = TrainDatasetFromFolder('data/DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR) val_set = ValDatasetFromFolder('data/DIV2K_valid_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR) train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=8, shuffle=True) val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False) netG = Generator(num_rrdb_blocks=16, scaling_factor=UPSCALE_FACTOR) print('# generator parameters:', sum(param.numel() for param in netG.parameters())) netD = Discriminator(opt.crop_size) print('# discriminator parameters:', sum(param.numel() for param in netD.parameters())) generator_criterion = EGeneratorLoss() discriminator_criterion = torch.nn.BCEWithLogitsLoss() if torch.cuda.is_available(): netG.cuda() netD.cuda() generator_criterion.cuda() discriminator_criterion.cuda() optimizerG = optim.Adam(netG.parameters(), lr=0.0001) optimizerD = optim.Adam(netD.parameters(), lr=0.0001) optimizerG_decay = optim.lr_scheduler.MultiStepLR(optimizerG, decay_indices, gamma=0.5) optimizerD_decay = optim.lr_scheduler.MultiStepLR(optimizerD, decay_indices, gamma=0.5)
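For reference, the two MultiStepLR schedulers above halve the learning rate each time training passes a milestone in decay_indices. A minimal sketch of the resulting schedule, assuming NUM_EPOCHS of 100 and the 1e-4 base rate used by the optimizers:

num_epochs = 100                 # assumed for illustration
interval = num_epochs // 8       # 12
milestones = [interval, interval * 2, interval * 4, interval * 6]

def lr_at_epoch(epoch, base_lr=1e-4, gamma=0.5):
    # LR after `epoch` epochs: halved once per milestone already passed.
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed

assert lr_at_epoch(0) == 1e-4
assert lr_at_epoch(interval) == 5e-5       # first halving
assert lr_at_epoch(99) == 1e-4 * 0.5 ** 4  # all four milestones passed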
def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: nn.init.normal_(m.weight.data, 0.0, 0.02) elif classname.find('BatchNorm') != -1: nn.init.normal_(m.weight.data, 1.0, 0.02) nn.init.constant_(m.bias.data, 0) netE = Encoder(ngpu, ndf, nc).cuda() if (device.type == 'cuda') and (ngpu > 1): netE = nn.DataParallel(netE, list(range(ngpu))) netE.apply(weights_init) netG = Decoder(ngpu, num, ngf).cuda() if (device.type == 'cuda') and (ngpu > 1): netG = nn.DataParallel(netG, list(range(ngpu))) netG.apply(weights_init) netD = Discriminator(ngpu, ndf, nc).cuda() if (device.type == 'cuda') and (ngpu > 1): netD = nn.DataParallel(netD, list(range(ngpu))) netD.apply(weights_init) # Loss functions and optimizers. criterionE = Encoder_loss criterionG = Decoder_loss criterionD = nn.BCELoss() # Optimizers. optimizerE = optim.Adam(netE.parameters(), lr=lr, betas=(beta1, 0.999)) optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999)) optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999)) real_label = 1 fake_label = 0
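A short sketch of how the scalar real_label / fake_label defined above are typically expanded into per-batch targets for criterionD; the batch size and the stand-in scores are illustrative, not from the original script:

import torch
import torch.nn as nn

batch_size = 64  # illustrative
real_targets = torch.full((batch_size,), float(real_label))
fake_targets = torch.full((batch_size,), float(fake_label))
scores = torch.sigmoid(torch.randn(batch_size))  # stand-in for netD(x).view(-1)
loss_real = nn.BCELoss()(scores, real_targets)   # pushes scores toward 1
loss_fake = nn.BCELoss()(scores, fake_targets)   # pushes scores toward 0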
def build_model(config, from_style, to_style): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") generator_ab = ResidualGenerator(config.image_size, config.num_residual_blocks).to(device) generator_ba = ResidualGenerator(config.image_size, config.num_residual_blocks).to(device) discriminator_a = Discriminator(config.image_size).to(device) discriminator_b = Discriminator(config.image_size).to(device) generator_ab_param = glob( os.path.join(config.checkpoint_dir, f"{from_style}2{to_style}", f"generator_ab_{config.epoch-1}.pth")) generator_ba_param = glob( os.path.join(config.checkpoint_dir, f"{from_style}2{to_style}", f"generator_ba_{config.epoch-1}.pth")) discriminator_a_param = glob( os.path.join(config.checkpoint_dir, f"{from_style}2{to_style}", f"discriminator_a_{config.epoch-1}.pth")) discriminator_b_param = glob( os.path.join(config.checkpoint_dir, f"{from_style}2{to_style}", f"discriminator_b_{config.epoch-1}.pth")) print(f"[*] Load checkpoint in {config.checkpoint_dir}") if not os.path.exists( os.path.join(config.checkpoint_dir, f"{from_style}2{to_style}")): os.makedirs( os.path.join(config.checkpoint_dir, f"{from_style}2{to_style}")) if len( os.listdir( os.path.join(config.checkpoint_dir, f"{from_style}2{to_style}"))) == 0: print(f"[!] No checkpoint in {config.checkpoint_dir}") generator_ab.apply(weights_init) generator_ba.apply(weights_init) discriminator_a.apply(weights_init) discriminator_b.apply(weights_init) else: generator_ab.load_state_dict( torch.load(generator_ab_param[-1], map_location=device)) generator_ba.load_state_dict( torch.load(generator_ba_param[-1], map_location=device)) discriminator_a.load_state_dict( torch.load(discriminator_a_param[-1], map_location=device)) discriminator_b.load_state_dict( torch.load(discriminator_b_param[-1], map_location=device)) return generator_ab, generator_ba, discriminator_a, discriminator_b
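A sketch of invoking this loader; the config object and style names are placeholders for whatever the surrounding training script supplies (build_model only reads image_size, num_residual_blocks, checkpoint_dir, and epoch):

from types import SimpleNamespace

config = SimpleNamespace(image_size=256, num_residual_blocks=9,
                         checkpoint_dir="checkpoints", epoch=200)
generator_ab, generator_ba, discriminator_a, discriminator_b = \
    build_model(config, "photo", "monet")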
def main(writer): dataset = AnimeDataset(avatar_tag_dat_path, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ])) data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True) G = Generator(noise_size, len(utils.hair) + len(utils.eyes)).to(device) D = Discriminator(len(utils.hair), len(utils.eyes)).to(device) G_optim = torch.optim.Adam(G.parameters(), lr=learning_rate_g, betas=(beta_1, 0.999)) D_optim = torch.optim.Adam(D.parameters(), lr=learning_rate_d, betas=(beta_1, 0.999)) criterion = nn.BCELoss() # training iteration = 0 real_label = torch.ones(batch_size).to(device) # real_label = torch.Tensor(batch_size).uniform_(0.9, 1).to(device) # soft labeling fake_label = torch.zeros(batch_size).to(device) for epoch in range(max_epoch + 1): for i, (real_tag, real_img) in enumerate(data_loader): real_img = real_img.to(device) real_tag = real_tag.to(device) # train D with real images D.zero_grad() real_score, real_predict = D(real_img) real_discrim_loss = criterion(real_score, real_label) real_classifier_loss = criterion(real_predict, real_tag) # train D with fake images z, fake_tag = utils.fake_generator(batch_size, noise_size, device) fake_img = G(z, fake_tag).to(device) fake_score, fake_predict = D(fake_img) fake_discrim_loss = criterion(fake_score, fake_label) discrim_loss = (real_discrim_loss + fake_discrim_loss) * 0.5 classifier_loss = real_classifier_loss * lambda_cls # gradient penalty alpha_size = [1] * real_img.dim() alpha_size[0] = real_img.size(0) alpha = torch.rand(alpha_size).to(device) x_hat = Variable(alpha * real_img.data + (1 - alpha) * (real_img.data + 0.5 * real_img.data.std() * torch.rand(real_img.size()).to(device)), requires_grad=True).to(device) fake_score, fake_tag = D(x_hat) gradients = grad(outputs=fake_score, inputs=x_hat, grad_outputs=torch.ones( fake_score.size()).to(device), create_graph=True, retain_graph=True, only_inputs=True)[0].view(x_hat.size(0), -1) gradient_penalty = lambda_gp * ( (gradients.norm(2, dim=1) - 1)**2).mean() D_loss = discrim_loss + classifier_loss + gradient_penalty D_loss.backward() D_optim.step() # train G G.zero_grad() z, fake_tag = utils.fake_generator(batch_size, noise_size, device) fake_img = G(z, fake_tag).to(device) fake_score, fake_predict = D(fake_img) discrim_loss = criterion(fake_score, real_label) classifier_loss = criterion(fake_predict, fake_tag) * lambda_cls G_loss = discrim_loss + classifier_loss G_loss.backward() G_optim.step() # plot loss curve writer.add_scalar('Loss_D', D_loss.item(), iteration) writer.add_scalar('Loss_G', G_loss.item(), iteration) print('[{}/{}][{}/{}] Iteration: {}'.format( epoch, max_epoch, i, len(data_loader), iteration)) if iteration % interval == interval - 1: fake_img = G(fix_noise, fix_tag) vutils.save_image(utils.denorm(fake_img[:64, :, :, :]), os.path.join( image_path, 'fake_image_{}.png'.format(iteration)), padding=0) vutils.save_image(utils.denorm(real_img[:64, :, :, :]), os.path.join( image_path, 'real_image_{}.png'.format(iteration)), padding=0) grid = vutils.make_grid(utils.denorm(fake_img[:64, :, :, :]), padding=0) writer.add_image('generation results', grid, iteration) iteration += 1 # checkpoint torch.save(G.state_dict(), os.path.join(model_path, 'netG_epoch_{}.pth'.format(epoch))) torch.save(D.state_dict(), os.path.join(model_path, 'netD_epoch_{}.pth'.format(epoch)))
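The penalty above is DRAGAN-flavored: x_hat interpolates between a real image and a noise-perturbed copy of it (rather than between real and generated samples as in WGAN-GP), and the critic's gradient norm at x_hat is pulled toward 1. A toy check of the (norm - 1)**2 term on a function whose gradient norm is known:

import torch

x = torch.randn(4, 3, requires_grad=True)
y = (2 * x).sum(dim=1)                     # dy/dx = 2 everywhere
g = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
gnorm = g.view(4, -1).norm(2, dim=1)       # = 2 * sqrt(3) for every sample
penalty = ((gnorm - 1) ** 2).mean()        # same form as the penalty above
assert torch.allclose(gnorm, torch.full((4,), 2 * 3 ** 0.5))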
def main(): parser = argparse.ArgumentParser() parser.add_argument("--checkpoint_dir", type=str, default="output/ckpt") parser.add_argument("--model_config", type=str, default="model_config.json") parser.add_argument("--no_cuda", action='store_true', help="Avoid using CUDA when available") parser.add_argument('--photo_dir', type=str, default="data/photo", help='path to photo datasets.') parser.add_argument('--edge_smooth_dir', type=str, default="data/edge_smooth", help='path to edge_smooth datasets.') parser.add_argument('--target_dir', type=str, default="data/target", help='path to target datasets.') parser.add_argument('--content_loss_weight', type=float, default=10, help='content loss weight') parser.add_argument('--seed', type=int, default=42, help='seed') parser.add_argument('--adam_beta', type=float, default=0.5, help='adam_beta') parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs of training') parser.add_argument('--n_init_epoch', type=int, default=15, help='number of epochs of initializing') parser.add_argument('--batch_size', type=int, default=8, help='size of the batches') parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate') parser.add_argument( '--n_cpu', type=int, default=0, help='number of cpu threads to use during batch generation') parser.add_argument('--logging_steps', type=int, default=50, help="Log every X updates steps.") parser.add_argument('--save_steps', type=int, default=3000, help='Save checkpoint every X updates steps.') args = parser.parse_args() # Setup logging logging.basicConfig( format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) model_config = Config.load(args.model_config) args.device = torch.device( "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") args.n_gpu = torch.cuda.device_count() set_seed(args) logger.warning("device: %s, n_gpu: %s", args.device, args.n_gpu) generator = Generator(model_config).to(args.device) discriminator = Discriminator(model_config).to(args.device) feature_extractor = FeatureExtractor(model_config).to(args.device) transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) ]) photo_dataloader, _ = load_image_dataloader(args.photo_dir, transform, args.batch_size, args.n_cpu) edge_smooth_dataloader, _ = load_image_dataloader(args.edge_smooth_dir, transform, args.batch_size, args.n_cpu) animation_dataloader, _ = load_image_dataloader(args.target_dir, transform, args.batch_size, args.n_cpu) train(args, generator, discriminator, feature_extractor, photo_dataloader, edge_smooth_dataloader, animation_dataloader, args.checkpoint_dir)
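Assuming this entry point lives in a file such as train.py (the filename is a guess), a typical invocation using the flags defined above would be:

# python train.py --photo_dir data/photo --edge_smooth_dir data/edge_smooth \
#     --target_dir data/target --batch_size 8 --n_epochs 100 --n_init_epoch 15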
train_set = TrainDatasetFromFolder('../../../CelebA-HQ-img/', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR) val_set = ValDatasetFromFolder('../../../CelebA-HQ-img/', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR) train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=opt.batchSize, shuffle=True) val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=opt.testBatchSize, shuffle=False) netG = Generator(UPSCALE_FACTOR).to(device) netD = Discriminator().to(device) # Load the pretrained generator (checkpoint hard-coded here; 'epochs/' + opt.pretrain_path is the configurable alternative). netG.load_state_dict(torch.load("./300000_G.pth")) optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(0.5, 0.9)) optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(0.5, 0.9)) criterionG = Generator_loss().to(device) # loop over the dataset multiple times for epoch in range(opt.start, NUM_EPOCHS + 1): d_total_loss = 0.0 g_total_loss = 0.0 for i, data in enumerate(train_loader, 0): inputs, labels = data
train_set = TrainDatasetFromFolder('../../../CelebA-HQ-img/', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR) val_set = ValDatasetFromFolder('../../../CelebA-HQ-img/', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR) train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=opt.batchSize, shuffle=True) val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=opt.testBatchSize, shuffle=False) netG = Generator(UPSCALE_FACTOR).to(device) netD = Discriminator().to(device) optimizerG = optim.RMSprop(netG.parameters(), lr=opt.lr) criterion = nn.MSELoss() # loop over the dataset multiple times for epoch in range(opt.start, NUM_EPOCHS + 1): running_loss = 0.0 for i, data in enumerate(train_loader, 0): inputs, labels = data inputs, labels = inputs.to(device), labels.to(device) # zero the parameter gradients optimizerG.zero_grad() # forward + backward + optimize
class CycleGANTrainr(object): def __init__( self, logf0s_normalization, mcep_normalization, coded_sps_A_norm, coded_sps_B_norm, model_checkpoint, validation_A_dir, output_A_dir, validation_B_dir, output_B_dir, restart_training_at=None, ): self.start_epoch = 0 self.num_epochs = 200000 self.mini_batch_size = 10 self.dataset_A = self.loadPickleFile(coded_sps_A_norm) self.dataset_B = self.loadPickleFile(coded_sps_B_norm) self.device = flow.device( "cuda" if flow.cuda.is_available() else "cpu") # Speech Parameters logf0s_normalization = np.load(logf0s_normalization) self.log_f0s_mean_A = logf0s_normalization["mean_A"] self.log_f0s_std_A = logf0s_normalization["std_A"] self.log_f0s_mean_B = logf0s_normalization["mean_B"] self.log_f0s_std_B = logf0s_normalization["std_B"] mcep_normalization = np.load(mcep_normalization) self.coded_sps_A_mean = mcep_normalization["mean_A"] self.coded_sps_A_std = mcep_normalization["std_A"] self.coded_sps_B_mean = mcep_normalization["mean_B"] self.coded_sps_B_std = mcep_normalization["std_B"] # Generator and Discriminator self.generator_A2B = Generator().to(self.device) self.generator_B2A = Generator().to(self.device) self.discriminator_A = Discriminator().to(self.device) self.discriminator_B = Discriminator().to(self.device) # Loss Functions criterion_mse = flow.nn.MSELoss() # Optimizer g_params = list(self.generator_A2B.parameters()) + list( self.generator_B2A.parameters()) d_params = list(self.discriminator_A.parameters()) + list( self.discriminator_B.parameters()) # Initial learning rates self.generator_lr = 2e-4 self.discriminator_lr = 1e-4 # Learning rate decay self.generator_lr_decay = self.generator_lr / 200000 self.discriminator_lr_decay = self.discriminator_lr / 200000 # Start learning rate decay after this many iterations have passed self.start_decay = 10000 self.generator_optimizer = flow.optim.Adam(g_params, lr=self.generator_lr, betas=(0.5, 0.999)) self.discriminator_optimizer = flow.optim.Adam( d_params, lr=self.discriminator_lr, betas=(0.5, 0.999)) # To load previously saved models self.modelCheckpoint = model_checkpoint os.makedirs(self.modelCheckpoint, exist_ok=True) # Validation set Parameters self.validation_A_dir = validation_A_dir self.output_A_dir = output_A_dir os.makedirs(self.output_A_dir, exist_ok=True) self.validation_B_dir = validation_B_dir self.output_B_dir = output_B_dir os.makedirs(self.output_B_dir, exist_ok=True) # Storing Discriminator and Generator Loss self.generator_loss_store = [] self.discriminator_loss_store = [] self.file_name = "log_store_non_sigmoid.txt" def adjust_lr_rate(self, optimizer, name="generator"): if name == "generator": self.generator_lr = max( 0.0, self.generator_lr - self.generator_lr_decay) for param_groups in optimizer.param_groups: param_groups["lr"] = self.generator_lr else: self.discriminator_lr = max( 0.0, self.discriminator_lr - self.discriminator_lr_decay) for param_groups in optimizer.param_groups: param_groups["lr"] = self.discriminator_lr def reset_grad(self): self.generator_optimizer.zero_grad() self.discriminator_optimizer.zero_grad() def train(self): # Training Begins for epoch in range(self.start_epoch, self.num_epochs): start_time_epoch = time.time() # Constants cycle_loss_lambda = 10 identity_loss_lambda = 5 # Preparing Dataset n_samples = len(self.dataset_A) dataset = trainingDataset(datasetA=self.dataset_A, datasetB=self.dataset_B, n_frames=128) train_loader = flow.utils.data.DataLoader( dataset=dataset, batch_size=self.mini_batch_size, shuffle=True, drop_last=False, ) pbar = 
tqdm(enumerate(train_loader)) for i, (real_A, real_B) in pbar: num_iterations = (n_samples // self.mini_batch_size) * epoch + i if num_iterations > 10000: identity_loss_lambda = 0 if num_iterations > self.start_decay: self.adjust_lr_rate(self.generator_optimizer, name="generator") self.adjust_lr_rate(self.discriminator_optimizer, name="discriminator") real_A = real_A.to(self.device).float() real_B = real_B.to(self.device).float() # Generator Loss function fake_B = self.generator_A2B(real_A) cycle_A = self.generator_B2A(fake_B) fake_A = self.generator_B2A(real_B) cycle_B = self.generator_A2B(fake_A) identity_A = self.generator_B2A(real_A) identity_B = self.generator_A2B(real_B) d_fake_A = self.discriminator_A(fake_A) d_fake_B = self.discriminator_B(fake_B) # for the second-step adversarial loss d_fake_cycle_A = self.discriminator_A(cycle_A) d_fake_cycle_B = self.discriminator_B(cycle_B) # Generator Cycle loss cycleLoss = flow.mean(flow.abs(real_A - cycle_A)) + flow.mean( flow.abs(real_B - cycle_B)) # Generator Identity Loss identityLoss = flow.mean( flow.abs(real_A - identity_A)) + flow.mean( flow.abs(real_B - identity_B)) # Generator Loss generator_loss_A2B = flow.mean((1 - d_fake_B)**2) generator_loss_B2A = flow.mean((1 - d_fake_A)**2) # Total Generator Loss generator_loss = (generator_loss_A2B + generator_loss_B2A + cycle_loss_lambda * cycleLoss + identity_loss_lambda * identityLoss) self.generator_loss_store.append(generator_loss.item()) # Backprop for Generator self.reset_grad() generator_loss.backward() self.generator_optimizer.step() # Discriminator Feed Forward d_real_A = self.discriminator_A(real_A) d_real_B = self.discriminator_B(real_B) generated_A = self.generator_B2A(real_B) d_fake_A = self.discriminator_A(generated_A) # for the second-step adversarial loss cycled_B = self.generator_A2B(generated_A) d_cycled_B = self.discriminator_B(cycled_B) generated_B = self.generator_A2B(real_A) d_fake_B = self.discriminator_B(generated_B) # for the second-step adversarial loss cycled_A = self.generator_B2A(generated_B) d_cycled_A = self.discriminator_A(cycled_A) # Loss Functions d_loss_A_real = flow.mean((1 - d_real_A)**2) d_loss_A_fake = flow.mean((0 - d_fake_A)**2) d_loss_A = (d_loss_A_real + d_loss_A_fake) / 2.0 d_loss_B_real = flow.mean((1 - d_real_B)**2) d_loss_B_fake = flow.mean((0 - d_fake_B)**2) d_loss_B = (d_loss_B_real + d_loss_B_fake) / 2.0 # the second-step adversarial loss d_loss_A_cycled = flow.mean((0 - d_cycled_A)**2) d_loss_B_cycled = flow.mean((0 - d_cycled_B)**2) d_loss_A_2nd = (d_loss_A_real + d_loss_A_cycled) / 2.0 d_loss_B_2nd = (d_loss_B_real + d_loss_B_cycled) / 2.0 # Final loss for the discriminator, including the second-step adversarial loss d_loss = (d_loss_A + d_loss_B) / 2.0 + (d_loss_A_2nd + d_loss_B_2nd) / 2.0 self.discriminator_loss_store.append(d_loss.item()) # Backprop for Discriminator self.reset_grad() d_loss.backward() self.discriminator_optimizer.step() if (i + 1) % 2 == 0: pbar.set_description( "Iter:{} Generator Loss:{:.4f} Discriminator Loss:{:.4f} GA2B:{:.4f} GB2A:{:.4f} G_id:{:.4f} G_cyc:{:.4f} D_A:{:.4f} D_B:{:.4f}" .format( num_iterations, generator_loss.item(), d_loss.item(), generator_loss_A2B, generator_loss_B2A, identityLoss, cycleLoss, d_loss_A, d_loss_B, )) if epoch % 2000 == 0 and epoch != 0: end_time = time.time() store_to_file = "Epoch: {} Generator Loss: {:.4f} Discriminator Loss: {}, Time: {:.2f}\n\n".format( epoch, generator_loss.item(), d_loss.item(), end_time - start_time_epoch, ) self.store_to_file(store_to_file) print( "Epoch: {} 
Generator Loss: {:.4f} Discriminator Loss: {}, Time: {:.2f}\n\n" .format( epoch, generator_loss.item(), d_loss.item(), end_time - start_time_epoch, )) # Save the Entire model print("Saving model Checkpoint ......") store_to_file = "Saving model Checkpoint ......" self.store_to_file(store_to_file) self.saveModelCheckPoint(epoch, self.modelCheckpoint) print("Model Saved!") if epoch % 2000 == 0 and epoch != 0: # Validation Set validation_start_time = time.time() self.validation_for_A_dir() self.validation_for_B_dir() validation_end_time = time.time() store_to_file = "Time taken for validation Set: {}".format( validation_end_time - validation_start_time) self.store_to_file(store_to_file) print("Time taken for validation Set: {}".format( validation_end_time - validation_start_time)) def infer(self, PATH="sample"): num_mcep = 36 sampling_rate = 16000 frame_period = 5.0 n_frames = 128 infer_A_dir = PATH output_A_dir = PATH for file in os.listdir(infer_A_dir): filePath = os.path.join(infer_A_dir, file) wav, _ = librosa.load(filePath, sr=sampling_rate, mono=True) wav = preprocess.wav_padding(wav=wav, sr=sampling_rate, frame_period=frame_period, multiple=4) f0, timeaxis, sp, ap = preprocess.world_decompose( wav=wav, fs=sampling_rate, frame_period=frame_period) f0_converted = preprocess.pitch_conversion( f0=f0, mean_log_src=self.log_f0s_mean_A, std_log_src=self.log_f0s_std_A, mean_log_target=self.log_f0s_mean_B, std_log_target=self.log_f0s_std_B, ) coded_sp = preprocess.world_encode_spectral_envelop( sp=sp, fs=sampling_rate, dim=num_mcep) coded_sp_transposed = coded_sp.T coded_sp_norm = (coded_sp_transposed - self.coded_sps_A_mean) / self.coded_sps_A_std coded_sp_norm = np.array([coded_sp_norm]) if flow.cuda.is_available(): coded_sp_norm = flow.tensor(coded_sp_norm).cuda().float() else: coded_sp_norm = flow.tensor(coded_sp_norm).float() coded_sp_converted_norm = self.generator_A2B(coded_sp_norm) coded_sp_converted_norm = coded_sp_converted_norm.cpu().detach( ).numpy() coded_sp_converted_norm = np.squeeze(coded_sp_converted_norm) coded_sp_converted = ( coded_sp_converted_norm * self.coded_sps_B_std + self.coded_sps_B_mean) coded_sp_converted = coded_sp_converted.T coded_sp_converted = np.ascontiguousarray(coded_sp_converted) decoded_sp_converted = preprocess.world_decode_spectral_envelop( coded_sp=coded_sp_converted, fs=sampling_rate) wav_transformed = preprocess.world_speech_synthesis( f0=f0_converted, decoded_sp=decoded_sp_converted, ap=ap, fs=sampling_rate, frame_period=frame_period, ) sf.write( os.path.join(output_A_dir, "convert_" + os.path.basename(file)), wav_transformed, sampling_rate, ) def validation_for_A_dir(self): num_mcep = 36 sampling_rate = 16000 frame_period = 5.0 n_frames = 128 validation_A_dir = self.validation_A_dir output_A_dir = self.output_A_dir print("Generating Validation Data B from A...") for file in os.listdir(validation_A_dir): filePath = os.path.join(validation_A_dir, file) wav, _ = librosa.load(filePath, sr=sampling_rate, mono=True) wav = preprocess.wav_padding(wav=wav, sr=sampling_rate, frame_period=frame_period, multiple=4) f0, timeaxis, sp, ap = preprocess.world_decompose( wav=wav, fs=sampling_rate, frame_period=frame_period) f0_converted = preprocess.pitch_conversion( f0=f0, mean_log_src=self.log_f0s_mean_A, std_log_src=self.log_f0s_std_A, mean_log_target=self.log_f0s_mean_B, std_log_target=self.log_f0s_std_B, ) coded_sp = preprocess.world_encode_spectral_envelop( sp=sp, fs=sampling_rate, dim=num_mcep) coded_sp_transposed = coded_sp.T coded_sp_norm = 
(coded_sp_transposed - self.coded_sps_A_mean) / self.coded_sps_A_std coded_sp_norm = np.array([coded_sp_norm]) if flow.cuda.is_available(): coded_sp_norm = flow.tensor(coded_sp_norm).cuda().float() else: coded_sp_norm = flow.tensor(coded_sp_norm).float() coded_sp_converted_norm = self.generator_A2B(coded_sp_norm) coded_sp_converted_norm = coded_sp_converted_norm.cpu().detach( ).numpy() coded_sp_converted_norm = np.squeeze(coded_sp_converted_norm) coded_sp_converted = ( coded_sp_converted_norm * self.coded_sps_B_std + self.coded_sps_B_mean) coded_sp_converted = coded_sp_converted.T coded_sp_converted = np.ascontiguousarray(coded_sp_converted) decoded_sp_converted = preprocess.world_decode_spectral_envelop( coded_sp=coded_sp_converted, fs=sampling_rate) wav_transformed = preprocess.world_speech_synthesis( f0=f0_converted, decoded_sp=decoded_sp_converted, ap=ap, fs=sampling_rate, frame_period=frame_period, ) sf.write( os.path.join(output_A_dir, os.path.basename(file)), wav_transformed, sampling_rate, ) def validation_for_B_dir(self): num_mcep = 36 sampling_rate = 16000 frame_period = 5.0 n_frames = 128 validation_B_dir = self.validation_B_dir output_B_dir = self.output_B_dir print("Generating Validation Data A from B...") for file in os.listdir(validation_B_dir): filePath = os.path.join(validation_B_dir, file) wav, _ = librosa.load(filePath, sr=sampling_rate, mono=True) wav = preprocess.wav_padding(wav=wav, sr=sampling_rate, frame_period=frame_period, multiple=4) f0, timeaxis, sp, ap = preprocess.world_decompose( wav=wav, fs=sampling_rate, frame_period=frame_period) f0_converted = preprocess.pitch_conversion( f0=f0, mean_log_src=self.log_f0s_mean_B, std_log_src=self.log_f0s_std_B, mean_log_target=self.log_f0s_mean_A, std_log_target=self.log_f0s_std_A, ) coded_sp = preprocess.world_encode_spectral_envelop( sp=sp, fs=sampling_rate, dim=num_mcep) coded_sp_transposed = coded_sp.T coded_sp_norm = (coded_sp_transposed - self.coded_sps_B_mean) / self.coded_sps_B_std coded_sp_norm = np.array([coded_sp_norm]) if flow.cuda.is_available(): coded_sp_norm = flow.tensor(coded_sp_norm).cuda().float() else: coded_sp_norm = flow.tensor(coded_sp_norm).float() coded_sp_converted_norm = self.generator_B2A(coded_sp_norm) coded_sp_converted_norm = coded_sp_converted_norm.cpu().detach( ).numpy() coded_sp_converted_norm = np.squeeze(coded_sp_converted_norm) coded_sp_converted = ( coded_sp_converted_norm * self.coded_sps_A_std + self.coded_sps_A_mean) coded_sp_converted = coded_sp_converted.T coded_sp_converted = np.ascontiguousarray(coded_sp_converted) decoded_sp_converted = preprocess.world_decode_spectral_envelop( coded_sp=coded_sp_converted, fs=sampling_rate) wav_transformed = preprocess.world_speech_synthesis( f0=f0_converted, decoded_sp=decoded_sp_converted, ap=ap, fs=sampling_rate, frame_period=frame_period, ) sf.write( os.path.join(output_B_dir, os.path.basename(file)), wav_transformed, sampling_rate, ) def savePickle(self, variable, fileName): with open(fileName, "wb") as f: pickle.dump(variable, f) def loadPickleFile(self, fileName): with open(fileName, "rb") as f: return pickle.load(f) def store_to_file(self, doc): doc = doc + "\n" with open(self.file_name, "a") as myfile: myfile.write(doc) def saveModelCheckPoint(self, epoch, PATH): flow.save( self.generator_A2B.state_dict(), os.path.join(PATH, "generator_A2B_%d" % epoch), ) flow.save( self.generator_B2A.state_dict(), os.path.join(PATH, "generator_B2A_%d" % epoch), ) flow.save( self.discriminator_A.state_dict(), os.path.join(PATH, "discriminator_A_%d" % 
epoch), ) flow.save( self.discriminator_B.state_dict(), os.path.join(PATH, "discriminator_B_%d" % epoch), ) def loadModel(self, PATH): self.generator_A2B.load_state_dict( flow.load(os.path.join(PATH, "generator_A2B"))) self.generator_B2A.load_state_dict( flow.load(os.path.join(PATH, "generator_B2A"))) self.discriminator_A.load_state_dict( flow.load(os.path.join(PATH, "discriminator_A"))) self.discriminator_B.load_state_dict( flow.load(os.path.join(PATH, "discriminator_B")))
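Note that saveModelCheckPoint writes epoch-suffixed file names (e.g. generator_A2B_2000) while loadModel expects the unsuffixed names, so restoring a specific epoch requires spelling the suffix out. A minimal epoch-aware loader sketch (the method name is hypothetical), mirroring saveModelCheckPoint's naming:

    def loadModelAtEpoch(self, PATH, epoch):
        # Mirrors saveModelCheckPoint's "<name>_<epoch>" convention.
        self.generator_A2B.load_state_dict(
            flow.load(os.path.join(PATH, "generator_A2B_%d" % epoch)))
        self.generator_B2A.load_state_dict(
            flow.load(os.path.join(PATH, "generator_B2A_%d" % epoch)))
        self.discriminator_A.load_state_dict(
            flow.load(os.path.join(PATH, "discriminator_A_%d" % epoch)))
        self.discriminator_B.load_state_dict(
            flow.load(os.path.join(PATH, "discriminator_B_%d" % epoch)))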
val_set = ValDatasetFromFolder('data/DIV2K_valid_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR) train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=4, shuffle=True) val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False) netG = Generator(num_rrdb_blocks=16, scaling_factor=UPSCALE_FACTOR) print('# generator parameters:', sum(param.numel() for param in netG.parameters())) netD = Discriminator(opt.crop_size) print('# discriminator parameters:', sum(param.numel() for param in netD.parameters())) netD_HF = Discriminator(image_size=dwt_size) generator_criterion = EGeneratorLoss() discriminator_criterion = torch.nn.BCEWithLogitsLoss() if torch.cuda.is_available(): netG.cuda() netD.cuda() netD_HF.cuda() generator_criterion.cuda() discriminator_criterion.cuda() optimizerG = optim.Adam(netG.parameters(), lr=0.0001)