def build_model(self):
    """Instantiate all networks and one Adam optimizer per training role.

    Creates Encoder/Decoder (autoencoder), a second Decoder used as the
    stage-two Generator, and the two discriminators, moving everything to
    GPU when one is available.
    """
    hps = self.hps
    ns = self.hps.ns
    emb_size = self.hps.emb_size
    self.Encoder = Encoder(ns=ns, dp=hps.enc_dp)
    self.Decoder = Decoder(ns=ns, c_a=hps.n_speakers, emb_size=emb_size)
    self.Generator = Decoder(ns=ns, c_a=hps.n_speakers, emb_size=emb_size)
    self.LatentDiscriminator = LatentDiscriminator(ns=ns, dp=hps.dis_dp)
    self.PatchDiscriminator = PatchDiscriminator(ns=ns,
                                                 n_class=hps.n_speakers)
    # Move every module to GPU in one pass when CUDA is present.
    if torch.cuda.is_available():
        for module in (self.Encoder, self.Decoder, self.Generator,
                       self.LatentDiscriminator, self.PatchDiscriminator):
            module.cuda()
    betas = (0.5, 0.9)
    # Encoder + Decoder share one optimizer (joint autoencoder update).
    ae_params = list(self.Encoder.parameters()) + \
        list(self.Decoder.parameters())
    self.ae_opt = optim.Adam(ae_params, lr=self.hps.lr, betas=betas)
    self.gen_opt = optim.Adam(self.Generator.parameters(),
                              lr=self.hps.lr, betas=betas)
    self.lat_opt = optim.Adam(self.LatentDiscriminator.parameters(),
                              lr=self.hps.lr, betas=betas)
    self.patch_opt = optim.Adam(self.PatchDiscriminator.parameters(),
                                lr=self.hps.lr, betas=betas)
def build_model(self):
    """Build all networks and their Adam optimizers.

    Modules are wrapped in cc() (presumably a to-device helper defined
    elsewhere — confirm) and the PatchDiscriminator additionally in
    nn.DataParallel.
    """
    hps = self.hps
    ns = self.hps.ns
    emb_size = self.hps.emb_size
    self.Encoder = cc(Encoder(ns=ns, dp=hps.enc_dp))
    self.Decoder = cc(Decoder(ns=ns, c_a=hps.n_speakers, emb_size=emb_size))
    self.Generator = cc(
        Decoder(ns=ns, c_a=hps.n_speakers, emb_size=emb_size))
    self.SpeakerClassifier = cc(
        SpeakerClassifier(ns=ns, n_class=hps.n_speakers, dp=hps.dis_dp))
    self.PatchDiscriminator = cc(
        nn.DataParallel(PatchDiscriminator(ns=ns, n_class=hps.n_speakers)))

    betas = (0.5, 0.9)

    def _adam(parameters):
        # All optimizers share the same lr/betas configuration.
        return optim.Adam(parameters, lr=self.hps.lr, betas=betas)

    ae_params = list(self.Encoder.parameters()) + \
        list(self.Decoder.parameters())
    self.ae_opt = _adam(ae_params)
    self.clf_opt = _adam(self.SpeakerClassifier.parameters())
    self.gen_opt = _adam(self.Generator.parameters())
    self.patch_opt = _adam(self.PatchDiscriminator.parameters())
def build_model(self, wavenet_mel):
    """Build networks and optimizers for either mel- or linear-spectrogram input.

    Args:
        wavenet_mel: if truthy, use 80-bin mel features (WaveNet vocoder
            setup); otherwise 513-bin linear spectrograms. The patch
            discriminator's classification kernel is sized to match.
    """
    hps = self.hps
    ns = self.hps.ns
    emb_size = self.hps.emb_size
    # Feature dimensionality and matching classifier kernel per input type.
    if wavenet_mel:
        n_bins, clf_kernel = 80, (3, 4)
    else:
        n_bins, clf_kernel = 513, (17, 4)
    self.Encoder = cc(Encoder(c_in=n_bins, ns=ns, dp=hps.enc_dp))
    self.Decoder = cc(
        Decoder(c_out=n_bins, ns=ns, c_a=hps.n_speakers, emb_size=emb_size))
    self.Generator = cc(
        Decoder(c_out=n_bins, ns=ns, c_a=hps.n_speakers, emb_size=emb_size))
    self.SpeakerClassifier = cc(
        SpeakerClassifier(ns=ns, n_class=hps.n_speakers, dp=hps.dis_dp))
    self.PatchDiscriminator = cc(
        nn.DataParallel(
            PatchDiscriminator(ns=ns,
                               n_class=hps.n_speakers,
                               classify_kernel_size=clf_kernel)))

    betas = (0.5, 0.9)

    def _adam(parameters):
        # Shared optimizer configuration for every network.
        return optim.Adam(parameters, lr=self.hps.lr, betas=betas)

    ae_params = list(self.Encoder.parameters()) + \
        list(self.Decoder.parameters())
    self.ae_opt = _adam(ae_params)
    self.clf_opt = _adam(self.SpeakerClassifier.parameters())
    self.gen_opt = _adam(self.Generator.parameters())
    self.patch_opt = _adam(self.PatchDiscriminator.parameters())
# NOTE(review): mangled fragment of an inpainting-GAN training script. It
# finishes a transforms.Compose([...]) augmentation pipeline (the opening call
# lies outside this view), builds a DataLoader driven by an InfiniteSampler,
# instantiates the generator (InpaintNet), a feature-patch discriminator and a
# patch discriminator with their L1/consistency losses and Adam optimizers,
# and opens a checkpoint-resume branch whose body is also outside this view.
# Both edges are cut, so the code is kept byte-identical.
transforms.RandomHorizontalFlip(), transforms.RandomGrayscale(), transforms.ToTensor(), ]) train_set = DS(args.root, train_tf) iterator_train = iter(data.DataLoader( train_set, batch_size=args.batch_size, sampler=InfiniteSampler(len(train_set)), num_workers=args.n_threads)) print(len(train_set)) g_model = InpaintNet().to(device) fd_model = FeaturePatchDiscriminator().to(device) pd_model = PatchDiscriminator().to(device) l1 = nn.L1Loss().to(device) cons = ConsistencyLoss().to(device) start_iter = 0 g_optimizer = torch.optim.Adam( g_model.parameters(), args.lr, (args.b1, args.b2)) fd_optimizer = torch.optim.Adam( fd_model.parameters(), args.lr, (args.b1, args.b2)) pd_optimizer = torch.optim.Adam( pd_model.parameters(), args.lr, (args.b1, args.b2)) if args.resume:
# NOTE(review): mangled fragment of a second training script (text-guided
# inpainting, judging by WORDS_NUM/BRANCH_NUM/ixtoword). It finishes a dataset
# constructor call whose opening lies outside this view, builds a shuffled
# DataLoader, instantiates InpaintNet + PatchDiscriminator with an L1 loss and
# Adam optimizers, and restores G/PD checkpoints when args.resume is set.
# The start of the statement is cut, so the code is kept byte-identical.
WORDS_NUM=args.WORDS_NUM, BRANCH_NUM=args.BRANCH_NUM, transform=train_tf) assert dataset_train train_set = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, drop_last=True, shuffle=True, num_workers=args.n_threads) print(len(train_set)) ixtoword_train = dataset_train.ixtoword g_model = InpaintNet().to(device) pd_model = PatchDiscriminator().to(device) l1 = nn.L1Loss().to(device) start_epoch = 0 g_optimizer_t = torch.optim.Adam(g_model.parameters(), args.lr, (args.b1, args.b2)) pd_optimizer_t = torch.optim.Adam(pd_model.parameters(), args.lr, (args.b1, args.b2)) if args.resume: g_checkpoint = torch.load(f'{args.save_dir}/ckpt/G_{args.resume}.pth', map_location=device) g_model.load_state_dict(g_checkpoint) pd_checkpoint = torch.load(f'{args.save_dir}/ckpt/PD_{args.resume}.pth', map_location=device) pd_model.load_state_dict(pd_checkpoint)
def build_model(self):
    """Build the two-stage networks and their optimizers.

    Stage one: Encoder/Decoder autoencoder plus a SpeakerClassifier on the
    latent code. Stage two: a Generator (second Decoder) and a
    PatchDiscriminator, both sized for the target-speaker set when
    self.targeted_G is on.
    """
    hps = self.hps
    ns = self.hps.ns
    emb_size = self.hps.emb_size
    betas = (0.5, 0.9)
    # When the encoder emits binary codes, the classifier sees the outer
    # product of the embedding (emb_size ** 2 inputs).
    clf_in = emb_size * emb_size if self.binary_output else emb_size
    # Targeted generation restricts stage two to the target-speaker subset.
    n_stage2_class = hps.n_target_speakers if self.targeted_G \
        else hps.n_speakers

    #---stage one---#
    self.Encoder = cc(
        Encoder(ns=ns,
                dp=hps.enc_dp,
                emb_size=emb_size,
                seg_len=hps.seg_len,
                one_hot=self.one_hot,
                binary_output=self.binary_output,
                binary_ver=self.binary_ver))
    self.Decoder = cc(
        Decoder(ns=ns,
                c_in=emb_size,
                c_h=emb_size,
                c_a=hps.n_speakers,
                seg_len=hps.seg_len,
                inp_emb=self.one_hot or self.binary_output))
    self.SpeakerClassifier = cc(
        SpeakerClassifier(ns=ns,
                          c_in=clf_in,
                          c_h=emb_size,
                          n_class=hps.n_speakers,
                          dp=hps.dis_dp,
                          seg_len=hps.seg_len))

    #---stage one opts---#
    ae_params = list(self.Encoder.parameters()) + \
        list(self.Decoder.parameters())
    self.ae_opt = optim.Adam(ae_params, lr=self.hps.lr, betas=betas)
    self.clf_opt = optim.Adam(self.SpeakerClassifier.parameters(),
                              lr=self.hps.lr, betas=betas)

    #---stage two---#
    self.Generator = cc(
        Decoder(ns=ns, c_in=emb_size, c_h=emb_size, c_a=n_stage2_class))
    self.PatchDiscriminator = cc(
        nn.DataParallel(
            PatchDiscriminator(ns=ns,
                               n_class=n_stage2_class,
                               seg_len=hps.seg_len)))

    #---stage two opts---#
    self.gen_opt = optim.Adam(self.Generator.parameters(),
                              lr=self.hps.lr, betas=betas)
    self.patch_opt = optim.Adam(self.PatchDiscriminator.parameters(),
                                lr=self.hps.lr, betas=betas)
class Solver(object):
    """Two-stage voice-conversion trainer.

    Stage one trains the Encoder/Decoder autoencoder adversarially against a
    LatentDiscriminator (speaker-invariant latent codes). Stage two (the
    "patch" steps) additionally trains the Decoder against a
    PatchDiscriminator with an auxiliary speaker-classification loss.
    """

    def __init__(self, hps, data_loader, log_dir='./log/'):
        self.hps = hps                  # hyper-parameter namespace
        self.data_loader = data_loader  # iterator yielding (c_i, c_j, x_i_t, x_i_tk, x_i_prime, x_j)
        self.model_kept = []            # rolling window of checkpoint paths
        self.max_keep = 20              # max checkpoints kept on disk
        self.build_model()
        self.logger = Logger(log_dir)

    def build_model(self):
        """Instantiate all networks and one Adam optimizer per training role."""
        hps = self.hps
        ns = self.hps.ns
        emb_size = self.hps.emb_size
        self.Encoder = Encoder(ns=ns, dp=hps.enc_dp)
        self.Decoder = Decoder(ns=ns, c_a=hps.n_speakers, emb_size=emb_size)
        self.Generator = Decoder(ns=ns, c_a=hps.n_speakers,
                                 emb_size=emb_size)
        self.LatentDiscriminator = LatentDiscriminator(ns=ns, dp=hps.dis_dp)
        self.PatchDiscriminator = PatchDiscriminator(ns=ns,
                                                     n_class=hps.n_speakers)
        if torch.cuda.is_available():
            self.Encoder.cuda()
            self.Decoder.cuda()
            self.Generator.cuda()
            self.LatentDiscriminator.cuda()
            self.PatchDiscriminator.cuda()
        betas = (0.5, 0.9)
        params = list(self.Encoder.parameters()) + list(
            self.Decoder.parameters())
        self.ae_opt = optim.Adam(params, lr=self.hps.lr, betas=betas)
        # BUGFIX: train() steps self.decoder_opt when updating the Decoder
        # against the patch discriminator, but no such optimizer existed
        # (AttributeError at runtime). Give the Decoder its own optimizer;
        # reusing ae_opt there would also apply stale Encoder gradients.
        self.decoder_opt = optim.Adam(self.Decoder.parameters(),
                                      lr=self.hps.lr, betas=betas)
        self.gen_opt = optim.Adam(self.Generator.parameters(),
                                  lr=self.hps.lr, betas=betas)
        self.lat_opt = optim.Adam(self.LatentDiscriminator.parameters(),
                                  lr=self.hps.lr, betas=betas)
        self.patch_opt = optim.Adam(self.PatchDiscriminator.parameters(),
                                    lr=self.hps.lr, betas=betas)

    def save_model(self, model_path, iteration, enc_only=True):
        """Save a checkpoint at '<model_path>-<iteration>'.

        With enc_only (default) only encoder/decoder/generator weights are
        stored; otherwise the discriminators are included too. Keeps at most
        self.max_keep checkpoints, deleting the oldest.
        """
        if not enc_only:
            all_model = {
                'encoder': self.Encoder.state_dict(),
                'decoder': self.Decoder.state_dict(),
                'generator': self.Generator.state_dict(),
                'latent_discriminator':
                    self.LatentDiscriminator.state_dict(),
                'patch_discriminator':
                    self.PatchDiscriminator.state_dict(),
            }
        else:
            all_model = {
                'encoder': self.Encoder.state_dict(),
                'decoder': self.Decoder.state_dict(),
                'generator': self.Generator.state_dict(),
            }
        new_model_path = '{}-{}'.format(model_path, iteration)
        with open(new_model_path, 'wb') as f_out:
            torch.save(all_model, f_out)
        self.model_kept.append(new_model_path)
        if len(self.model_kept) >= self.max_keep:
            os.remove(self.model_kept[0])
            self.model_kept.pop(0)

    def load_model(self, model_path, enc_only=True):
        """Restore weights saved by save_model().

        NOTE(review): the saved 'generator' weights are never restored here
        (the old code had this commented out, with a typo: self.Genrator);
        confirm whether the Generator should also be loaded.
        """
        print('load model from {}'.format(model_path))
        with open(model_path, 'rb') as f_in:
            all_model = torch.load(f_in)
        self.Encoder.load_state_dict(all_model['encoder'])
        self.Decoder.load_state_dict(all_model['decoder'])
        if not enc_only:
            self.LatentDiscriminator.load_state_dict(
                all_model['latent_discriminator'])
            self.PatchDiscriminator.load_state_dict(
                all_model['patch_discriminator'])

    def set_eval(self):
        """Switch inference-time modules to eval mode (dropout/BN behavior)."""
        self.Encoder.eval()
        self.Decoder.eval()
        self.Generator.eval()
        #self.LatentDiscriminator.eval()

    def test_step(self, x, c):
        """Convert utterance x to speaker c; returns a numpy array.

        Assumes x is (batch, time, freq) and is permuted to channel-first
        — TODO confirm against the data pipeline.
        """
        self.set_eval()
        x = to_var(x).permute(0, 2, 1)
        enc = self.Encoder(x)
        x_tilde = self.Decoder(enc, c)
        return x_tilde.data.cpu().numpy()

    def permute_data(self, data):
        """Split a batch into speaker-id vars and channel-first utterance vars."""
        C = [to_var(c, requires_grad=False) for c in data[:2]]
        X = [to_var(x).permute(0, 2, 1) for x in data[2:]]
        return C, X

    def sample_c(self, size):
        """Draw `size` speaker ids uniformly at random.

        BUGFIX: the number of speakers was hard-coded to 8; use
        hps.n_speakers instead (identical behavior when n_speakers == 8).
        """
        c_sample = Variable(torch.multinomial(
            torch.ones(self.hps.n_speakers),
            num_samples=size,
            replacement=True), requires_grad=False)
        c_sample = c_sample.cuda() if torch.cuda.is_available() else c_sample
        return c_sample

    def cal_acc(self, logits, y_true):
        """Classification accuracy of argmax(logits) against y_true."""
        _, ind = torch.max(logits, dim=1)
        acc = torch.sum(
            (ind == y_true).type(torch.FloatTensor)) / y_true.size(0)
        return acc

    def encode_step(self, *args):
        """Encode each utterance; returns a tuple of latent codes."""
        enc_list = []
        for x in args:
            enc = self.Encoder(x)
            enc_list.append(enc)
        return tuple(enc_list)

    def decode_step(self, enc, c):
        """Decode latent code enc conditioned on speaker c."""
        x_tilde = self.Decoder(enc, c)
        return x_tilde

    def latent_discriminate_step(self, enc_i_t, enc_i_tk, enc_i_prime,
                                 enc_j, is_dis=True):
        """WGAN-GP objective on the latent space.

        is_dis=True: return (w_dis, gradient_penalty) for the discriminator.
        is_dis=False: return the adversarial loss for the encoder.
        """
        same_pair = torch.cat([enc_i_t, enc_i_tk], dim=1)
        diff_pair = torch.cat([enc_i_prime, enc_j], dim=1)
        if is_dis:
            same_val = self.LatentDiscriminator(same_pair)
            diff_val = self.LatentDiscriminator(diff_pair)
            w_dis = torch.mean(same_val - diff_val)
            gp = calculate_gradients_penalty(self.LatentDiscriminator,
                                             same_pair, diff_pair)
            return w_dis, gp
        else:
            diff_val = self.LatentDiscriminator(diff_pair)
            loss_adv = -torch.mean(diff_val)
            return loss_adv

    def patch_discriminate_step(self, x, x_tilde, cal_gp=True):
        """WGAN-GP objective on spectrogram patches, with speaker logits.

        Returns (w_dis, real_logits, fake_logits[, gp]).
        """
        # w-distance
        D_real, real_logits = self.PatchDiscriminator(x, classify=True)
        D_fake, fake_logits = self.PatchDiscriminator(x_tilde, classify=True)
        w_dis = torch.mean(D_real - D_fake)
        if cal_gp:
            gp = calculate_gradients_penalty(self.PatchDiscriminator,
                                             x, x_tilde)
            return w_dis, real_logits, fake_logits, gp
        else:
            return w_dis, real_logits, fake_logits

    def train(self, model_path, flag='train'):
        """Run the full training loop, checkpointing every 1000 iterations.

        NOTE: the .data[0] scalar accessor below is the pre-0.4 PyTorch
        idiom, kept for behavioral compatibility with the rest of this file
        (to_var / Variable usage).
        """
        # load hyperparams
        hps = self.hps
        for iteration in range(hps.iters):
            # Linearly ramp the latent-adversarial weight after encoder
            # pretraining, capping at alpha_enc once lat_sched_iters is hit.
            if iteration + 1 < hps.lat_sched_iters and \
                    iteration >= hps.enc_pretrain_iters:
                current_alpha = hps.alpha_enc * (
                    iteration + 1 - hps.enc_pretrain_iters) / (
                        hps.lat_sched_iters - hps.enc_pretrain_iters)
            else:
                current_alpha = 0
            if iteration >= hps.enc_pretrain_iters:
                # Extra discriminator steps on the first post-pretrain pass.
                n_latent_steps = hps.n_latent_steps \
                    if iteration > hps.enc_pretrain_iters \
                    else hps.dis_pretrain_iters
                for step in range(n_latent_steps):
                    #===================== Train latent discriminator =====================#
                    data = next(self.data_loader)
                    (c_i, c_j), (x_i_t, x_i_tk, x_i_prime, x_j) = \
                        self.permute_data(data)
                    # encode
                    enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
                        x_i_t, x_i_tk, x_i_prime, x_j)
                    # latent discriminate
                    latent_w_dis, latent_gp = self.latent_discriminate_step(
                        enc_i_t, enc_i_tk, enc_i_prime, enc_j)
                    lat_loss = -hps.alpha_dis * latent_w_dis + \
                        hps.lambda_ * latent_gp
                    reset_grad([self.LatentDiscriminator])
                    lat_loss.backward()
                    grad_clip([self.LatentDiscriminator],
                              self.hps.max_grad_norm)
                    self.lat_opt.step()
                    # print info
                    info = {
                        f'{flag}/D_latent_w_dis': latent_w_dis.data[0],
                        f'{flag}/latent_gp': latent_gp.data[0],
                    }
                    slot_value = (step, iteration + 1, hps.iters) + \
                        tuple([value for value in info.values()])
                    log = 'lat_D-%d:[%06d/%06d], w_dis=%.3f, gp=%.2f'
                    print(log % slot_value)
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration)
            # two stage training
            if iteration >= hps.patch_start_iter:
                for step in range(hps.n_patch_steps):
                    #===================== Train patch discriminator =====================#
                    data = next(self.data_loader)
                    (c_i, _), (x_i_t, _, _, _) = self.permute_data(data)
                    # encode
                    enc_i_t, = self.encode_step(x_i_t)
                    c_sample = self.sample_c(x_i_t.size(0))
                    x_tilde = self.decode_step(enc_i_t, c_i)
                    patch_w_dis, real_logits, fake_logits, patch_gp = \
                        self.patch_discriminate_step(x_i_t, x_tilde,
                                                     cal_gp=True)
                    # BUGFIX: c_loss / real_acc / fake_acc were referenced
                    # below but never computed (NameError); restored from the
                    # commented-out classify() backup that shipped with this
                    # file. NOTE(review): x_tilde is decoded with c_i while
                    # the fake target is c_sample — confirm this asymmetry.
                    criterion = torch.nn.NLLLoss()
                    c_loss = criterion(real_logits, c_i) + \
                        criterion(fake_logits, c_sample)
                    real_acc = self.cal_acc(real_logits, c_i)
                    fake_acc = self.cal_acc(fake_logits, c_sample)
                    patch_loss = -hps.beta_dis * patch_w_dis + \
                        hps.lambda_ * patch_gp + hps.beta_clf * c_loss
                    reset_grad([self.PatchDiscriminator])
                    patch_loss.backward()
                    grad_clip([self.PatchDiscriminator],
                              self.hps.max_grad_norm)
                    self.patch_opt.step()
                    # print info
                    info = {
                        f'{flag}/D_patch_w_dis': patch_w_dis.data[0],
                        f'{flag}/patch_gp': patch_gp.data[0],
                        f'{flag}/c_loss': c_loss.data[0],
                        f'{flag}/real_acc': real_acc,
                        f'{flag}/fake_acc': fake_acc,
                    }
                    slot_value = (step, iteration + 1, hps.iters) + \
                        tuple([value for value in info.values()])
                    log = 'patch_D-%d:[%06d/%06d], w_dis=%.3f, gp=%.2f, c_loss=%.3f, real_acc=%.2f, fake_acc=%.2f'
                    print(log % slot_value)
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration)
            #===================== Train G =====================#
            data = next(self.data_loader)
            (c_i, c_j), (x_i_t, x_i_tk, x_i_prime, x_j) = \
                self.permute_data(data)
            # encode
            enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
                x_i_t, x_i_tk, x_i_prime, x_j)
            # decode
            x_tilde = self.decode_step(enc_i_t, c_i)
            loss_rec = torch.mean(torch.abs(x_tilde - x_i_t))
            # latent discriminate
            loss_adv = self.latent_discriminate_step(enc_i_t, enc_i_tk,
                                                     enc_i_prime, enc_j,
                                                     is_dis=False)
            ae_loss = loss_rec + current_alpha * loss_adv
            reset_grad([self.Encoder, self.Decoder])
            # Keep the graph alive when the patch-adversarial Decoder update
            # below will backprop through the same encoding.
            retain_graph = True if hps.n_patch_steps > 0 else False
            ae_loss.backward(retain_graph=retain_graph)
            grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
            self.ae_opt.step()
            info = {
                f'{flag}/loss_rec': loss_rec.data[0],
                f'{flag}/loss_adv': loss_adv.data[0],
                f'{flag}/alpha': current_alpha,
            }
            slot_value = (iteration + 1, hps.iters) + tuple(
                [value for value in info.values()])
            log = 'G:[%06d/%06d], loss_rec=%.2f, loss_adv=%.2f, alpha=%.2e'
            print(log % slot_value)
            for tag, value in info.items():
                self.logger.scalar_summary(tag, value, iteration + 1)
            # patch discriminate
            if hps.n_patch_steps > 0 and iteration >= hps.patch_start_iter:
                c_sample = self.sample_c(x_i_t.size(0))
                x_tilde = self.decode_step(enc_i_t, c_sample)
                patch_w_dis, real_logits, fake_logits = \
                    self.patch_discriminate_step(x_i_t, x_tilde,
                                                 cal_gp=False)
                # BUGFIX: c_loss / real_acc / fake_acc were undefined here as
                # well; computed as in the patch-discriminator step.
                criterion = torch.nn.NLLLoss()
                c_loss = criterion(real_logits, c_i) + \
                    criterion(fake_logits, c_sample)
                real_acc = self.cal_acc(real_logits, c_i)
                fake_acc = self.cal_acc(fake_logits, c_sample)
                patch_loss = hps.beta_dec * patch_w_dis + \
                    hps.beta_clf * c_loss
                reset_grad([self.Decoder])
                patch_loss.backward()
                grad_clip([self.Decoder], self.hps.max_grad_norm)
                # BUGFIX: this optimizer is now created in build_model (it
                # was missing entirely, so this line raised AttributeError).
                self.decoder_opt.step()
                info = {
                    f'{flag}/G_patch_w_dis': patch_w_dis.data[0],
                    f'{flag}/c_loss': c_loss.data[0],
                    f'{flag}/real_acc': real_acc,
                    f'{flag}/fake_acc': fake_acc,
                }
                slot_value = (iteration + 1, hps.iters) + tuple(
                    [value for value in info.values()])
                log = 'G:[%06d/%06d]: patch_w_dis=%.2f, c_loss=%.2f, real_acc=%.2f, fake_acc=%.2f'
                print(log % slot_value)
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            if iteration % 1000 == 0 or iteration + 1 == hps.iters:
                self.save_model(model_path, iteration)