def finetune_discrete_decoder(trainer, look_up, model_path, flag='train'):
    """Fine-tune the decoder of an already-trained `trainer` on discretized targets.

    Each step: encode a batch, map the target through `look_up`, decode, and
    minimize the L1 reconstruction loss.  Progress is printed in place,
    scalars are logged every 100 steps, and a checkpoint tagged 'dc' is
    written every 1000 steps.
    """
    hyper_params = trainer.hps
    total_iters = hyper_params.enc_pretrain_iters
    for iteration in range(total_iters):
        batch = next(trainer.data_loader)
        c, x = trainer.permute_data(batch)
        encoded = trainer.encode_step(x)
        # discretize the reconstruction target via the lookup table
        x = look_up(x)
        x_tilde = trainer.decode_step(encoded, c)
        # L1 reconstruction loss
        loss_rec = torch.mean(torch.abs(x_tilde - x))
        reset_grad([trainer.Encoder, trainer.Decoder])
        loss_rec.backward()
        grad_clip([trainer.Encoder, trainer.Decoder],
                  trainer.hps.max_grad_norm)
        trainer.ae_opt.step()
        # tensorboard info
        info = {
            f'{flag}/disc_loss_rec': loss_rec.item(),
        }
        slot_value = (iteration + 1, total_iters) + tuple(info.values())
        print('train_discrete:[%06d/%06d], loss_rec=%.3f' % slot_value,
              end='\r')
        if iteration % 100 == 0:
            for tag, value in info.items():
                trainer.logger.scalar_summary(tag, value, iteration + 1)
        if (iteration + 1) % 1000 == 0:
            trainer.save_model(model_path, 'dc', iteration + 1)
    print()
def train(self, model_path, flag='train'):
    """Train the speaker classifier for `hps.iters` steps.

    Logs train loss/acc on every step; every 1000 iterations (and on the
    final one) runs validation over 10 batches, logs the results, and
    checkpoints the model to `model_path`.
    """
    # load hyperparams
    hps = self.hps
    for iteration in range(hps.iters):
        data = next(self.data_loader)
        y, x = self.permute_data(data)
        # encode
        enc = self.encode_step(x)
        # forward to classifier
        logits = self.forward_step(enc)
        # calculate loss
        loss = self.cal_loss(logits, y)
        # optimize
        reset_grad([self.SpeakerClassifier])
        loss.backward()
        grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
        self.opt.step()
        # calculate acc
        acc = cal_acc(logits, y)
        # print info
        info = {
            # FIX: `loss.data[0]` is a PyTorch<=0.3 idiom; indexing a 0-dim
            # tensor raises on modern PyTorch.  Use `.item()`, consistent
            # with the other training loops in this file.
            f'{flag}/loss': loss.item(),
            f'{flag}/acc': acc,
        }
        slot_value = (iteration + 1, hps.iters) + \
            tuple([value for value in info.values()])
        log = 'iter:[%06d/%06d], loss=%.3f, acc=%.3f'
        print(log % slot_value, end='\r')
        for tag, value in info.items():
            self.logger.scalar_summary(tag, value, iteration)
        if iteration % 1000 == 0 or iteration + 1 == hps.iters:
            valid_loss, valid_acc = self.valid(n_batches=10)
            # print info
            info = {
                f'{flag}/valid_loss': valid_loss,
                f'{flag}/valid_acc': valid_acc,
            }
            slot_value = (iteration + 1, hps.iters) + \
                tuple([value for value in info.values()])
            log = 'iter:[%06d/%06d], valid_loss=%.3f, valid_acc=%.3f'
            print(log % slot_value)
            for tag, value in info.items():
                self.logger.scalar_summary(tag, value, iteration)
            self.save_model(model_path, iteration)
# NOTE(review): fragment of a larger training loop — `o`, `o_next`, `i`,
# `epoch`, `data`, `n_trajs`, `C`, `C_solver`, `params`, `savepath` are
# defined by enclosing code that is not visible here.
# Contrastive (CPC-style) update: score the true next observation against
# random observations drawn from every other trajectory.
x = Variable(torch.FloatTensor(np.transpose(o, (2, 0, 1))[None]))
x_next = Variable(
    torch.FloatTensor(np.transpose(o_next, (2, 0, 1))[None]))
# latent code of the current observation
z = C.encode(x)
# unnormalized log-density of the true successor under z
density_x = C.density(x_next, z)
density_sum = 0
# negative samples: one random frame from each other trajectory
for j in [n for n in range(n_trajs) if n != i]:
    k = np.random.randint(len(data[j]))
    o_other = data[j][k][0]
    # NOTE(review): x / x_next above are built on CPU even when CUDA is
    # available, while x_other is moved to CUDA — confirm this mix is
    # intentional (it would fail if C's parameters live on the GPU).
    if torch.cuda.is_available():
        x_other = Variable(
            torch.cuda.FloatTensor(
                np.transpose(o_other, (2, 0, 1))[None]))
    else:
        x_other = Variable(
            torch.FloatTensor(np.transpose(o_other, (2, 0, 1))[None]))
    # exp of the log-density ratio negative-vs-positive
    density_sum += torch.exp(C.density(x_other, z) - density_x)
# InfoNCE-style probability that the positive outranks the negatives
density = 1.0 / (1.0 + density_sum)
C_loss = -torch.mean(torch.log(density))
C_loss.backward()
C_solver.step()
# gradients are zeroed after the step, ready for the next iteration
reset_grad(params)
print('********** Epoch %i ************' % epoch)
print(C_loss)
log_value('C_loss', C_loss, epoch)
if not os.path.exists('%s/var' % savepath):
    os.makedirs('%s/var' % savepath)
torch.save(C.state_dict(), '%s/var/cpc%d' % (savepath, epoch))
def train(self, model_path, flag='train', mode='train'):
    """Dispatch training by `mode`: autoencoder pretraining ('pretrain_G'),
    speaker-classifier pretraining ('pretrain_D'), or adversarial patch-GAN
    training ('patchGAN').

    NOTE(review): this block is truncated in the source — it ends
    mid-statement inside the patchGAN logging loop; the remainder of the
    method is missing from this view.
    """
    # load hyperparams
    hps = self.hps
    if mode == 'pretrain_G':
        # Pretrain encoder+decoder on L1 reconstruction only.
        for iteration in range(hps.enc_pretrain_iters):
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            x_tilde = self.decode_step(enc, c)
            loss_rec = torch.mean(torch.abs(x_tilde - x))
            reset_grad([self.Encoder, self.Decoder])
            loss_rec.backward()
            grad_clip([self.Encoder, self.Decoder],
                      self.hps.max_grad_norm)
            self.ae_opt.step()
            # tb info
            info = {
                f'{flag}/pre_loss_rec': loss_rec.item(),
            }
            slot_value = (iteration + 1, hps.enc_pretrain_iters) + tuple(
                [value for value in info.values()])
            log = 'pre_G:[%06d/%06d], loss_rec=%.3f'
            print(log % slot_value)
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
    elif mode == 'pretrain_D':
        # Pretrain the speaker classifier on encoder outputs.
        for iteration in range(hps.dis_pretrain_iters):
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            # classify speaker
            logits = self.clf_step(enc)
            loss_clf = self.cal_loss(logits, c)
            # update
            reset_grad([self.SpeakerClassifier])
            loss_clf.backward()
            grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
            self.clf_opt.step()
            # calculate acc
            acc = cal_acc(logits, c)
            info = {
                f'{flag}/pre_loss_clf': loss_clf.item(),
                f'{flag}/pre_acc': acc,
            }
            slot_value = (iteration + 1, hps.dis_pretrain_iters) + tuple(
                [value for value in info.values()])
            log = 'pre_D:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
            print(log % slot_value)
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
    elif mode == 'patchGAN':
        # Adversarial training: WGAN-GP patch discriminator with an
        # auxiliary speaker-classification head.
        for iteration in range(hps.patch_iters):
            #=======train D=========#
            for step in range(hps.n_patch_steps):
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                ## encode
                enc = self.encode_step(x)
                # sample c
                c_prime = self.sample_c(x.size(0))
                # generator
                x_tilde = self.gen_step(enc, c_prime)
                # discriminstor
                w_dis, real_logits, gp = self.patch_step(
                    x, x_tilde, is_dis=True)
                # aux classification loss
                loss_clf = self.cal_loss(real_logits, c)
                # WGAN-GP critic loss + auxiliary classifier loss
                loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                reset_grad([self.PatchDiscriminator])
                loss.backward()
                grad_clip([self.PatchDiscriminator],
                          self.hps.max_grad_norm)
                self.patch_opt.step()
                # calculate acc
                acc = cal_acc(real_logits, c)
                info = {
                    f'{flag}/w_dis': w_dis.item(),
                    f'{flag}/gp': gp.item(),
                    f'{flag}/real_loss_clf': loss_clf.item(),
                    f'{flag}/real_acc': acc,
                }
                slot_value = (step, iteration+1, hps.patch_iters) + tuple(
                    [value for value in info.values()])
                log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f'
                print(log % slot_value)
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
            #=======train G=========#
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            # sample c
            c_prime = self.sample_c(x.size(0))
            # generator
            x_tilde = self.gen_step(enc, c_prime)
            # discriminstor
            loss_adv, fake_logits = self.patch_step(x, x_tilde, is_dis=False)
            # aux classification loss
            loss_clf = self.cal_loss(fake_logits, c_prime)
            loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
            reset_grad([self.Generator])
            loss.backward()
            grad_clip([self.Generator], self.hps.max_grad_norm)
            self.gen_opt.step()
            # calculate acc
            acc = cal_acc(fake_logits, c_prime)
            info = {
                f'{flag}/loss_adv': loss_adv.item(),
                f'{flag}/fake_loss_clf': loss_clf.item(),
                f'{flag}/fake_acc': acc,
            }
            slot_value = (iteration+1, hps.patch_iters) + tuple(
                [value for value in info.values()])
            log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f'
            print(log % slot_value)
            if iteration % 100 == 0:
                # NOTE(review): source truncated here — the body of this
                # loop (presumably a scalar_summary call) is missing.
                for tag, value in info.items():
# NOTE(review): fragment — this belongs to the `mode == 'train'` branch of a
# train() method; the enclosing `for iteration in range(...)` header was
# lost in extraction.  `iteration`, `hps`, `flag`, and `self` come from
# that enclosing scope.
# Linearly ramp the adversarial weight alpha up to hps.alpha_enc over the
# first hps.lat_sched_iters iterations, then hold it constant.
if iteration < hps.lat_sched_iters:
    current_alpha = hps.alpha_enc * (iteration / hps.lat_sched_iters)
else:
    current_alpha = hps.alpha_enc
#==================train D==================#
# Train the speaker classifier (the "discriminator" on the latent space).
for step in range(hps.n_latent_steps):
    data = next(self.data_loader)
    c, x = self.permute_data(data)
    # encode
    enc = self.encode_step(x)
    # classify speaker
    logits = self.clf_step(enc)
    loss_clf = self.cal_loss(logits, c)
    loss = hps.alpha_dis * loss_clf
    # update
    reset_grad([self.SpeakerClassifier])
    loss.backward()
    grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
    self.clf_opt.step()
    # calculate acc
    acc = cal_acc(logits, c)
    info = {
        f'{flag}/D_loss_clf': loss_clf.item(),
        f'{flag}/D_acc': acc,
    }
    slot_value = (step, iteration + 1, hps.iters) + tuple(
        [value for value in info.values()])
    log = 'D-%d:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
    print(log % slot_value)
    if iteration % 100 == 0:
        for tag, value in info.items():
            self.logger.scalar_summary(tag, value, iteration + 1)
def train(self, model_path, flag='train', mode='train'):
    """Train according to `mode` with hard-coded (shortened) schedules.

    Modes:
      'pretrain_G' — encoder/decoder L1 reconstruction pretraining.
      'pretrain_D' — speaker-classifier pretraining on encoder outputs.
      'patchGAN'   — WGAN-GP patch discriminator + generator training.
      'train'      — adversarial autoencoder training (D on latents, G).

    Checkpoints go to `<model_path>/model.pkl`.
    """
    # ensure the checkpoint directory exists and is world-readable
    if not os.path.isdir(model_path):
        os.makedirs(model_path)
        os.chmod(model_path, 0o755)
    model_path = os.path.join(model_path, 'model.pkl')
    # load hyperparams
    hps = self.hps
    # FIX: the loops below run hard-coded iteration counts, but the
    # original final-checkpoint conditions compared `iteration + 1`
    # against `hps.patch_iters` / `hps.iters` — so the last-iteration
    # save silently never fired whenever those hyperparams differed from
    # the hard-coded bounds.  Name the bounds once and use them
    # consistently for both the loops and the save conditions.
    pretrain_iters = 2200
    gan_iters = 1100
    if mode == 'pretrain_G':
        for iteration in range(pretrain_iters):
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            x_tilde = self.decode_step(enc, c)
            loss_rec = torch.mean(torch.abs(x_tilde - x))
            reset_grad([self.Encoder, self.Decoder])
            loss_rec.backward()
            grad_clip([self.Encoder, self.Decoder],
                      self.hps.max_grad_norm)
            self.ae_opt.step()
            # tb info
            info = {
                f'{flag}/pre_loss_rec': loss_rec.item(),
            }
            slot_value = (iteration + 1, pretrain_iters) + tuple(
                [value for value in info.values()])
            log = 'pre_G:[%06d/%06d], loss_rec=%.3f'
            print(log % slot_value)
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
    elif mode == 'pretrain_D':
        for iteration in range(pretrain_iters):
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            # classify speaker
            logits = self.clf_step(enc)
            loss_clf = self.cal_loss(logits, c)
            # update
            reset_grad([self.SpeakerClassifier])
            loss_clf.backward()
            grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
            self.clf_opt.step()
            # calculate acc
            acc = cal_acc(logits, c)
            info = {
                f'{flag}/pre_loss_clf': loss_clf.item(),
                f'{flag}/pre_acc': acc,
            }
            slot_value = (iteration + 1, pretrain_iters) + tuple(
                [value for value in info.values()])
            log = 'pre_D:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
            print(log % slot_value)
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
    elif mode == 'patchGAN':
        for iteration in range(gan_iters):
            #=======train D=========#
            for step in range(hps.n_patch_steps):
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                ## encode
                enc = self.encode_step(x)
                # sample c
                c_prime = self.sample_c(x.size(0))
                # generator
                x_tilde = self.gen_step(enc, c_prime)
                # discriminstor
                w_dis, real_logits, gp = self.patch_step(
                    x, x_tilde, is_dis=True)
                # aux classification loss
                loss_clf = self.cal_loss(real_logits, c)
                # WGAN-GP critic loss + auxiliary classifier loss
                loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                reset_grad([self.PatchDiscriminator])
                loss.backward()
                grad_clip([self.PatchDiscriminator],
                          self.hps.max_grad_norm)
                self.patch_opt.step()
                # calculate acc
                acc = cal_acc(real_logits, c)
                info = {
                    f'{flag}/w_dis': w_dis.item(),
                    f'{flag}/gp': gp.item(),
                    f'{flag}/real_loss_clf': loss_clf.item(),
                    f'{flag}/real_acc': acc,
                }
                slot_value = (step, iteration + 1, gan_iters) + tuple(
                    [value for value in info.values()])
                log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f'
                print(log % slot_value)
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
            #=======train G=========#
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            # sample c
            c_prime = self.sample_c(x.size(0))
            # generator
            x_tilde = self.gen_step(enc, c_prime)
            # discriminstor
            loss_adv, fake_logits = self.patch_step(x, x_tilde, is_dis=False)
            # aux classification loss
            loss_clf = self.cal_loss(fake_logits, c_prime)
            loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
            reset_grad([self.Generator])
            loss.backward()
            grad_clip([self.Generator], self.hps.max_grad_norm)
            self.gen_opt.step()
            # calculate acc
            acc = cal_acc(fake_logits, c_prime)
            info = {
                f'{flag}/loss_adv': loss_adv.item(),
                f'{flag}/fake_loss_clf': loss_clf.item(),
                f'{flag}/fake_acc': acc,
            }
            slot_value = (iteration + 1, gan_iters) + tuple(
                [value for value in info.values()])
            log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f'
            print(log % slot_value)
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            # checkpoint every 1000 iterations and on the final iteration
            # (step index offset by hps.iters to follow the 'train' stage)
            if iteration % 1000 == 0 or iteration + 1 == gan_iters:
                self.save_model(model_path, iteration + hps.iters)
    elif mode == 'train':
        for iteration in range(gan_iters):
            # calculate current alpha: linear ramp of the adversarial
            # weight up to hps.alpha_enc over hps.lat_sched_iters steps
            if iteration < hps.lat_sched_iters:
                current_alpha = hps.alpha_enc * (iteration / hps.lat_sched_iters)
            else:
                current_alpha = hps.alpha_enc
            #==================train D==================#
            for step in range(hps.n_latent_steps):
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                # encode
                enc = self.encode_step(x)
                # classify speaker
                logits = self.clf_step(enc)
                loss_clf = self.cal_loss(logits, c)
                loss = hps.alpha_dis * loss_clf
                # update
                reset_grad([self.SpeakerClassifier])
                loss.backward()
                grad_clip([self.SpeakerClassifier],
                          self.hps.max_grad_norm)
                self.clf_opt.step()
                # calculate acc
                acc = cal_acc(logits, c)
                info = {
                    f'{flag}/D_loss_clf': loss_clf.item(),
                    f'{flag}/D_acc': acc,
                }
                slot_value = (step, iteration + 1, gan_iters) + tuple(
                    [value for value in info.values()])
                log = 'D-%d:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
                print(log % slot_value)
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
            #==================train G==================#
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            # decode
            x_tilde = self.decode_step(enc, c)
            loss_rec = torch.mean(torch.abs(x_tilde - x))
            # classify speaker
            logits = self.clf_step(enc)
            acc = cal_acc(logits, c)
            loss_clf = self.cal_loss(logits, c)
            # maximize classification loss (adversarial on the encoder)
            loss = loss_rec - current_alpha * loss_clf
            reset_grad([self.Encoder, self.Decoder])
            loss.backward()
            grad_clip([self.Encoder, self.Decoder],
                      self.hps.max_grad_norm)
            self.ae_opt.step()
            info = {
                f'{flag}/loss_rec': loss_rec.item(),
                f'{flag}/G_loss_clf': loss_clf.item(),
                f'{flag}/alpha': current_alpha,
                f'{flag}/G_acc': acc,
            }
            slot_value = (iteration + 1, gan_iters) + tuple(
                [value for value in info.values()])
            log = 'G:[%06d/%06d], loss_rec=%.3f, loss_clf=%.2f, alpha=%.2e, acc=%.2f'
            print(log % slot_value)
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            # checkpoint every 1000 iterations and on the final iteration
            if iteration % 1000 == 0 or iteration + 1 == gan_iters:
                self.save_model(model_path, iteration)
def train(self, model_path, flag='train', mode='train'):
    """Train by `mode`: 'pretrain_G', 'pretrain_D', or 'patchGAN'.

    NOTE(review): the tail of the patchGAN branch appears to be a merge of
    two code variants — it references names that are never defined in this
    method (`current_alpha`, `c_loss`) and reuses `loss_clf`/`acc` from an
    earlier section in its final logging dict.  As written, the second
    "Train G" section would raise NameError.  Flagged, not fixed.
    """
    # load hyperparams
    hps = self.hps
    if mode == 'pretrain_G':
        # Pretrain encoder+decoder on L1 reconstruction only.
        for iteration in range(hps.enc_pretrain_iters):
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            x_tilde = self.decode_step(enc, c)
            loss_rec = torch.mean(torch.abs(x_tilde - x))
            reset_grad([self.Encoder, self.Decoder])
            loss_rec.backward()
            grad_clip([self.Encoder, self.Decoder],
                      self.hps.max_grad_norm)
            self.ae_opt.step()
            # tb info
            info = {
                f'{flag}/pre_loss_rec': loss_rec.item(),
            }
            slot_value = (iteration + 1, hps.enc_pretrain_iters) + tuple(
                [value for value in info.values()])
            log = 'pre_G:[%06d/%06d], loss_rec=%.3f'
            print(log % slot_value)
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
    elif mode == 'pretrain_D':
        # Pretrain the speaker classifier on encoder outputs.
        for iteration in range(hps.dis_pretrain_iters):
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            # classify speaker
            logits = self.clf_step(enc)
            loss_clf = self.cal_loss(logits, c)
            # update
            reset_grad([self.SpeakerClassifier])
            loss_clf.backward()
            grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
            self.clf_opt.step()
            # calculate acc
            acc = cal_acc(logits, c)
            info = {
                f'{flag}/pre_loss_clf': loss_clf.item(),
                f'{flag}/pre_acc': acc,
            }
            slot_value = (iteration + 1, hps.dis_pretrain_iters) + tuple(
                [value for value in info.values()])
            log = 'pre_D:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
            print(log % slot_value)
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
    elif mode == 'patchGAN':
        for iteration in range(hps.patch_iters):
            #=======train D=========#
            for step in range(hps.n_patch_steps):
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                ## encode
                enc = self.encode_step(x)
                # sample c
                c_prime = self.sample_c(x.size(0))
                # generator
                x_tilde = self.gen_step(enc, c_prime)
                # discriminstor
                w_dis, real_logits, gp = self.patch_step(
                    x, x_tilde, is_dis=True)
                # aux classification loss
                loss_clf = self.cal_loss(real_logits, c)
                # WGAN-GP critic loss + auxiliary classifier loss
                loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                reset_grad([self.PatchDiscriminator])
                loss.backward()
                grad_clip([self.PatchDiscriminator],
                          self.hps.max_grad_norm)
                self.patch_opt.step()
                # calculate acc
                acc = cal_acc(real_logits, c)
                info = {
                    f'{flag}/w_dis': w_dis.item(),
                    f'{flag}/gp': gp.item(),
                    f'{flag}/real_loss_clf': loss_clf.item(),
                    f'{flag}/real_acc': acc,
                }
                slot_value = (step, iteration + 1, hps.patch_iters) + tuple(
                    [value for value in info.values()])
                log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f'
                print(log % slot_value)
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
            #=======train G=========#
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            # sample c
            c_prime = self.sample_c(x.size(0))
            # generator
            x_tilde = self.gen_step(enc, c_prime)
            # discriminstor
            loss_adv, fake_logits = self.patch_step(x, x_tilde, is_dis=False)
            # aux classification loss
            loss_clf = self.cal_loss(fake_logits, c_prime)
            loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
            reset_grad([self.Generator])
            loss.backward()
            grad_clip([self.Generator], self.hps.max_grad_norm)
            self.gen_opt.step()
            # calculate acc
            acc = cal_acc(fake_logits, c_prime)
            info = {
                f'{flag}/loss_adv': loss_adv.item(),
                f'{flag}/fake_loss_clf': loss_clf.item(),
                f'{flag}/fake_acc': acc,
            }
            slot_value = (iteration + 1, hps.patch_iters) + tuple(
                [value for value in info.values()])
            log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f'
            print(log % slot_value)
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration)
            #===================== Train G =====================#
            # NOTE(review): from here on the code matches a different
            # variant of this trainer (pair-sampled latent discriminator);
            # `current_alpha` and `c_loss` are undefined in this method.
            data = next(self.data_loader)
            (c_i, c_j), (x_i_t, x_i_tk, x_i_prime, x_j) = self.permute_data(data)
            # encode
            enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
                x_i_t, x_i_tk, x_i_prime, x_j)
            # decode
            x_tilde = self.decode_step(enc_i_t, c_i)
            loss_rec = torch.mean(torch.abs(x_tilde - x_i_t))
            # latent discriminate
            loss_adv = self.latent_discriminate_step(
                enc_i_t, enc_i_tk, enc_i_prime, enc_j, is_dis=False)
            ae_loss = loss_rec + current_alpha * loss_adv
            reset_grad([self.Encoder, self.Decoder])
            # keep the graph alive only if the patch step below reuses it
            retain_graph = True if hps.n_patch_steps > 0 else False
            ae_loss.backward(retain_graph=retain_graph)
            grad_clip([self.Encoder, self.Decoder],
                      self.hps.max_grad_norm)
            self.ae_opt.step()
            info = {
                # NOTE(review): `.data[0]` is a PyTorch<=0.3 idiom; the
                # rest of this file uses `.item()`.
                f'{flag}/loss_rec': loss_rec.data[0],
                f'{flag}/loss_adv': loss_adv.data[0],
                f'{flag}/alpha': current_alpha,
            }
            slot_value = (iteration + 1, hps.iters) + tuple(
                [value for value in info.values()])
            log = 'G:[%06d/%06d], loss_rec=%.2f, loss_adv=%.2f, alpha=%.2e'
            print(log % slot_value)
            for tag, value in info.items():
                self.logger.scalar_summary(tag, value, iteration + 1)
            # patch discriminate
            if hps.n_patch_steps > 0 and iteration >= hps.patch_start_iter:
                c_sample = self.sample_c(x_i_t.size(0))
                x_tilde = self.decode_step(enc_i_t, c_sample)
                patch_w_dis, real_logits, fake_logits = \
                    self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=False)
                # NOTE(review): `c_loss` is never defined — NameError here.
                patch_loss = hps.beta_dec * patch_w_dis + hps.beta_clf * c_loss
                reset_grad([self.Decoder])
                patch_loss.backward()
                grad_clip([self.Decoder], self.hps.max_grad_norm)
                self.decoder_opt.step()
            # NOTE(review): `loss_clf` / `acc` below are stale values from
            # the earlier patch-G section of this iteration.
            info = {
                f'{flag}/loss_rec': loss_rec.item(),
                f'{flag}/G_loss_clf': loss_clf.item(),
                f'{flag}/alpha': current_alpha,
                f'{flag}/G_acc': acc,
            }
            slot_value = (iteration + 1, hps.iters) + tuple(
                [value for value in info.values()])
            log = 'G:[%06d/%06d], loss_rec=%.3f, loss_clf=%.2f, alpha=%.2e, acc=%.2f'
            print(log % slot_value)
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            if iteration % 1000 == 0 or iteration + 1 == hps.iters:
                self.save_model(model_path, iteration)
def train(self, model_path, flag='train', mode='train', target_guided=False):
    """Multi-stage trainer dispatched on `mode`.

    Modes:
      'pretrain_AE'    — encoder/decoder L1 reconstruction pretraining.
      'pretrain_C'     — speaker-classifier pretraining on latents.
      'train'          — stage-1 adversarial autoencoder training.
      'patchGAN'       — stage-2 cross-domain GAN (source→target).
      'autolocker'     — re-encoding-consistency training of the generator.
      't_classify'     — target-speaker classifier training.
      'train_Tacotron' — Tacotron-style synthesizer training on latents.

    `target_guided` adds a teacher-forced reconstruction loss in the
    'patchGAN' and 'autolocker' stages.  Checkpoints are written every 1000
    iterations with a per-stage tag ('ae', 'c', 's1', 's2', 'tclf', 't').
    """
    # load hyperparams
    hps = self.hps
    if mode == 'pretrain_AE':
        for iteration in range(hps.enc_pretrain_iters):
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc_act, enc = self.encode_step(x)
            x_dec = self.decode_step(enc_act, c)
            # L1 reconstruction loss
            loss_rec = torch.mean(torch.abs(x_dec - x))
            reset_grad([self.Encoder, self.Decoder])
            loss_rec.backward()
            grad_clip([self.Encoder, self.Decoder], hps.max_grad_norm)
            self.ae_opt.step()
            # tb info
            info = {
                f'{flag}/pre_loss_rec': loss_rec.item(),
            }
            slot_value = (iteration + 1, hps.enc_pretrain_iters) + tuple(
                [value for value in info.values()])
            log = 'pre_AE:[%06d/%06d], loss_rec=%.3f'
            print(log % slot_value, end='\r')
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            if (iteration + 1) % 1000 == 0:
                self.save_model(model_path, 'ae', iteration + 1)
        print()
    elif mode == 'pretrain_C':
        for iteration in range(hps.dis_pretrain_iters):
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc_act, enc = self.encode_step(x)
            # classify speaker
            logits = self.clf_step(enc)
            loss_clf = self.cal_loss(logits, c)
            # update
            reset_grad([self.SpeakerClassifier])
            loss_clf.backward()
            grad_clip([self.SpeakerClassifier], hps.max_grad_norm)
            self.clf_opt.step()
            # calculate acc
            acc = self.cal_acc(logits, c)
            info = {
                f'{flag}/pre_loss_clf': loss_clf.item(),
                f'{flag}/pre_acc': acc,
            }
            slot_value = (iteration + 1, hps.dis_pretrain_iters) + tuple(
                [value for value in info.values()])
            log = 'pre_C:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
            print(log % slot_value, end='\r')
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            if (iteration + 1) % 1000 == 0:
                self.save_model(model_path, 'c', iteration + 1)
        print()
    elif mode == 'train':
        for iteration in range(hps.iters):
            # calculate current alpha: linear ramp of the adversarial
            # weight up to hps.alpha_enc over hps.lat_sched_iters steps
            if iteration < hps.lat_sched_iters:
                current_alpha = hps.alpha_enc * (iteration / hps.lat_sched_iters)
            else:
                current_alpha = hps.alpha_enc
            #==================train D==================#
            for step in range(hps.n_latent_steps):
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                # encode
                enc_act, enc = self.encode_step(x)
                # classify speaker
                logits = self.clf_step(enc)
                loss_clf = self.cal_loss(logits, c)
                loss = hps.alpha_dis * loss_clf
                # update
                reset_grad([self.SpeakerClassifier])
                loss.backward()
                grad_clip([self.SpeakerClassifier], hps.max_grad_norm)
                self.clf_opt.step()
                # calculate acc
                acc = self.cal_acc(logits, c)
                info = {
                    f'{flag}/D_loss_clf': loss_clf.item(),
                    f'{flag}/D_acc': acc,
                }
                slot_value = (step, iteration + 1, hps.iters) + tuple(
                    [value for value in info.values()])
                log = 'D-%d:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
                print(log % slot_value, end='\r')
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
            #==================train G==================#
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc_act, enc = self.encode_step(x)
            # decode
            x_dec = self.decode_step(enc_act, c)
            loss_rec = torch.mean(torch.abs(x_dec - x))
            # classify speaker
            logits = self.clf_step(enc)
            acc = self.cal_acc(logits, c)
            loss_clf = self.cal_loss(logits, c)
            # maximize classification loss (adversarial on the encoder)
            loss = loss_rec - current_alpha * loss_clf
            reset_grad([self.Encoder, self.Decoder])
            loss.backward()
            grad_clip([self.Encoder, self.Decoder], hps.max_grad_norm)
            self.ae_opt.step()
            info = {
                f'{flag}/loss_rec': loss_rec.item(),
                f'{flag}/G_loss_clf': loss_clf.item(),
                f'{flag}/alpha': current_alpha,
                f'{flag}/G_acc': acc,
            }
            slot_value = (iteration + 1, hps.iters) + tuple(
                [value for value in info.values()])
            log = 'G:[%06d/%06d], loss_rec=%.3f, loss_clf=%.2f, alpha=%.2e, acc=%.2f'
            print(log % slot_value, end='\r')
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            if (iteration + 1) % 1000 == 0:
                self.save_model(model_path, 's1', iteration + 1)
        print()
    elif mode == 'patchGAN':
        for iteration in range(hps.patch_iters):
            #==================train D==================#
            for step in range(hps.n_patch_steps):
                # one source batch (content) + one target batch (style)
                data_s = next(self.source_loader)
                data_t = next(self.target_loader)
                _, x_s = self.permute_data(data_s)
                c_t, x_t = self.permute_data(data_t)
                # encode
                enc_act, _ = self.encode_step(x_s)
                # generator
                x_dec = self.gen_step(enc_act, c_t)
                # discriminstor
                w_dis, real_logits, gp = self.patch_step(
                    x_t, x_dec, is_dis=True)
                # aux classification loss (shift=True remaps target ids)
                loss_clf = self.cal_loss(real_logits, c_t, shift=True)
                # WGAN-GP critic loss + auxiliary classifier loss
                loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                reset_grad([self.PatchDiscriminator])
                loss.backward()
                grad_clip([self.PatchDiscriminator], hps.max_grad_norm)
                self.patch_opt.step()
                # calculate acc
                acc = self.cal_acc(real_logits, c_t, shift=True)
                info = {
                    f'{flag}/w_dis': w_dis.item(),
                    f'{flag}/gp': gp.item(),
                    f'{flag}/real_loss_clf': loss_clf.item(),
                    f'{flag}/real_acc': acc,
                }
                slot_value = (step, iteration + 1, hps.patch_iters) + tuple(
                    [value for value in info.values()])
                log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f'
                print(log % slot_value, end='\r')
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
            #==================train G==================#
            data_s = next(self.source_loader)
            data_t = next(self.target_loader)
            _, x_s = self.permute_data(data_s)
            c_t, x_t = self.permute_data(data_t)
            # encode
            enc_act, _ = self.encode_step(x_s)
            # generator
            x_dec = self.gen_step(enc_act, c_t)
            # discriminstor
            loss_adv, fake_logits = self.patch_step(x_t, x_dec, is_dis=False)
            # aux classification loss
            loss_clf = self.cal_loss(fake_logits, c_t, shift=True)
            loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
            reset_grad([self.Generator])
            loss.backward()
            grad_clip([self.Generator], hps.max_grad_norm)
            self.gen_opt.step()
            if target_guided:
                # teacher forcing: reconstruct the target from its own code
                enc_tf, _ = self.encode_step(x_t)
                x_dec_tf = self.gen_step(enc_tf, c_t)
                loss_rec = torch.mean(torch.abs(x_dec_tf - x_t))
                reset_grad([self.Generator])
                loss_rec.backward()
                self.gen_opt.step()
            # calculate acc
            acc = self.cal_acc(fake_logits, c_t, shift=True)
            info = {
                f'{flag}/loss_adv': loss_adv.item(),
                f'{flag}/fake_loss_clf': loss_clf.item(),
                f'{flag}/fake_acc': acc,
                # tg_rec is only meaningful when target_guided is on
                f'{flag}/tg_rec': loss_rec.item() if target_guided else 0.000,
            }
            slot_value = (iteration + 1, hps.patch_iters) + tuple(
                [value for value in info.values()])
            log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f, tg_rec=%.3f'
            print(log % slot_value, end='\r')
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            if (iteration + 1) % 1000 == 0:
                self.save_model(model_path, 's2', iteration + 1)
        print()
    elif mode == 'autolocker':
        # Train the generator so its output re-encodes to the input code.
        criterion = torch.nn.BCELoss()
        for iteration in range(hps.patch_iters):
            #==================train G==================#
            data_s = next(self.source_loader)
            data_t = next(self.target_loader)
            _, x_s = self.permute_data(data_s)
            c_t, x_t = self.permute_data(data_t)
            # encode
            enc_act, _ = self.encode_step(x_s)
            # decode
            residual_output = self.gen_step(enc_act, c_t)
            # re-encode
            re_enc, _ = self.encode_step(residual_output)
            # re-encode loss against the detached original code
            loss_reenc = criterion(re_enc, enc_act.data)
            reset_grad([self.Encoder, self.Decoder, self.Generator])
            loss_reenc.backward()
            # only the generator is clipped and stepped here
            grad_clip([self.Generator], hps.max_grad_norm)
            self.gen_opt.step()
            if target_guided:
                # teacher forcing
                enc_tf, _ = self.encode_step(x_t)
                x_dec_tf = self.gen_step(enc_tf, c_t)
                loss_rec = torch.mean(torch.abs(x_dec_tf - x_t))
                reset_grad([self.Encoder, self.Decoder, self.Generator])
                loss_rec.backward()
                self.gen_opt.step()
            # calculate acc
            info = {
                f'{flag}/re_enc': loss_reenc.item(),
                f'{flag}/tg_rec': loss_rec.item() if target_guided else 0.000,
            }
            slot_value = (iteration + 1, hps.patch_iters) + tuple(
                [value for value in info.values()])
            log = 'patch_G:[%06d/%06d], re_enc=%.3f, tg_rec=%.3f'
            print(log % slot_value, end='\r')
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            if (iteration + 1) % 1000 == 0:
                self.save_model(model_path, 's2', iteration + 1)
        print()
    elif mode == 't_classify':
        for iteration in range(hps.tclf_iters):
            #======train target classifier======#
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # collapse all non-target speakers into one id
            # (presumably ids >= 100 are the target speakers — TODO confirm)
            c[c < 100] = 102
            # classification
            logits = self.tclf_step(x)
            # classification loss, labels shifted into classifier range
            loss = self.cal_loss(logits, c - self.shift_c)
            reset_grad([self.TargetClassifier])
            loss.backward()
            grad_clip([self.TargetClassifier], hps.max_grad_norm)
            self.tclf_opt.step()
            # calculate acc
            acc = self.cal_acc(logits, c - self.shift_c)
            info = {
                f'{flag}/acc': acc,
            }
            slot_value = (iteration + 1, hps.tclf_iters) + tuple(
                [value for value in info.values()])
            log = 'Target Classifier:[%05d/%05d], acc=%.2f'
            print(log % slot_value, end='\r')
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            if (iteration + 1) % 1000 == 0:
                self.save_model(model_path, 'tclf', iteration + 1)
        print()
    elif mode == 'train_Tacotron':
        assert self.g_mode == 'tacotron'
        criterion = TacotronLoss()
        # encoder is frozen (eval mode) while the synthesizer trains
        self.Encoder.eval()
        for iteration in range(hps.tacotron_iters):
            #======train tacotron======#
            # per-step learning-rate decay applied to the generator opt
            cur_lr = learning_rate_decay(init_lr=0.002,
                                         global_step=iteration)
            for param_group in self.gen_opt.param_groups:
                param_group['lr'] = cur_lr
            data = next(self.data_loader)
            c, x, m = self.permute_data(data, load_mel=True)
            # encode
            enc_act, enc = self.encode_step(x)
            # tacotron synthesis (codes detached via .data)
            m_dec, x_dec = self.tacotron_step(enc_act.data, m, c)
            # reconstruction loss on mel + linear spectrograms
            loss_rec = criterion([m_dec, x_dec], [m, x])
            reset_grad([self.Generator])
            loss_rec.backward()
            grad_clip([self.Generator], hps.max_grad_norm)
            self.gen_opt.step()
            # tb info
            info = {
                f'{flag}/tacotron_loss_rec': loss_rec.item(),
                f'{flag}/tacotron_lr': cur_lr,
            }
            slot_value = (iteration + 1, hps.tacotron_iters) + tuple(
                [value for value in info.values()])
            log = 'train_Tacotron:[%06d/%06d], loss_rec=%.3f, lr=%.2e'
            print(log % slot_value, end='\r')
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            if (iteration + 1) % 1000 == 0:
                self.save_model(model_path, 't', iteration + 1)
        print()
    else:
        raise NotImplementedError()
def train(self, model_path, flag='train', mode='train'):
    """Dispatch one of four training regimes, selected by `mode`.

    Modes:
      'pretrain_AE' — L1-reconstruction pretraining of Encoder+Decoder.
      'pretrain_C'  — speaker-classifier pretraining on encoder outputs.
      'train'       — adversarial stage 1: classifier (D) vs. autoencoder (G),
                      with a linearly ramped adversarial weight alpha.
      'patchGAN'    — adversarial stage 2: WGAN-GP patch discriminator with an
                      auxiliary speaker-classification head vs. the Generator.

    Args:
        model_path: checkpoint path prefix handed to self.save_model.
        flag: tag prefix for TensorBoard scalar names (e.g. 'train/...').
        mode: one of the four regime names above; anything else raises
            NotImplementedError.
    """
    # load hyperparams
    hps = self.hps
    if mode == 'pretrain_AE':
        for iteration in range(hps.enc_pretrain_iters):
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            x_tilde = self.decode_step(enc, c)
            # L1 reconstruction loss
            loss_rec = torch.mean(torch.abs(x_tilde - x))
            reset_grad([self.Encoder, self.Decoder])
            loss_rec.backward()
            grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
            self.ae_opt.step()
            # tb info
            info = {
                f'{flag}/pre_loss_rec': loss_rec.item(),
            }
            slot_value = (iteration + 1, hps.enc_pretrain_iters) + \
                tuple([value for value in info.values()])
            log = 'pre_AE:[%06d/%06d], loss_rec=%.3f'
            print(log % slot_value, end='\r')
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            if (iteration + 1) % 1000 == 0:
                self.save_model(model_path, 'ae', iteration + 1)
        print()
    elif mode == 'pretrain_C':
        for iteration in range(hps.dis_pretrain_iters):
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            # classify speaker
            logits = self.clf_step(enc)
            loss_clf = self.cal_loss(logits, c)
            # update
            reset_grad([self.SpeakerClassifier])
            loss_clf.backward()
            grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
            self.clf_opt.step()
            # calculate acc
            acc = self.cal_acc(logits, c)
            info = {
                f'{flag}/pre_loss_clf': loss_clf.item(),
                f'{flag}/pre_acc': acc,
            }
            slot_value = (iteration + 1, hps.dis_pretrain_iters) + \
                tuple([value for value in info.values()])
            log = 'pre_C:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
            print(log % slot_value, end='\r')
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            if (iteration + 1) % 1000 == 0:
                self.save_model(model_path, 'c', iteration + 1)
        print()
    elif mode == 'train':
        for iteration in range(hps.iters):
            # calculate current alpha: linear ramp from 0 to hps.alpha_enc
            # over the first lat_sched_iters iterations, then constant.
            if iteration < hps.lat_sched_iters:
                current_alpha = hps.alpha_enc * \
                    (iteration / hps.lat_sched_iters)
            else:
                current_alpha = hps.alpha_enc
            #==================train D==================#
            for step in range(hps.n_latent_steps):
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                # encode
                enc = self.encode_step(x)
                # NOTE(review): enc looks like a (sample, mean, log_var)
                # triple; clf_step is presumably fed the full tuple — confirm.
                _, z_mean, z_log_var = enc
                # Gaussian KL divergence against the unit-normal prior.
                kl_loss = (1 + z_log_var - z_mean**2 -
                           torch.exp(z_log_var)).sum(-1) * -.5
                kl_loss = kl_loss.sum()
                # classify speaker
                logits = self.clf_step(enc)
                loss_clf = self.cal_loss(logits, c)
                loss = hps.alpha_dis * loss_clf + kl_loss
                # update
                reset_grad([self.SpeakerClassifier])
                loss.backward()
                grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
                self.clf_opt.step()
                # calculate acc
                acc = self.cal_acc(logits, c)
                info = {
                    f'{flag}/D_loss_clf': loss_clf.item(),
                    f'{flag}/D_acc': acc,
                }
                slot_value = (step, iteration + 1, hps.iters) + \
                    tuple([value for value in info.values()])
                log = 'D-%d:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
                print(log % slot_value, end='\r')
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
            #==================train G==================#
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            # decode
            x_tilde = self.decode_step(enc, c)
            loss_rec = torch.mean(torch.abs(x_tilde - x))
            # classify speaker
            logits = self.clf_step(enc)
            acc = self.cal_acc(logits, c)
            loss_clf = self.cal_loss(logits, c)
            # maximize classification loss (adversarial term for the encoder)
            loss = loss_rec - current_alpha * loss_clf
            reset_grad([self.Encoder, self.Decoder])
            loss.backward()
            grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
            self.ae_opt.step()
            info = {
                f'{flag}/loss_rec': loss_rec.item(),
                f'{flag}/G_loss_clf': loss_clf.item(),
                f'{flag}/alpha': current_alpha,
                f'{flag}/G_acc': acc,
            }
            slot_value = (iteration + 1, hps.iters) + \
                tuple([value for value in info.values()])
            log = 'G:[%06d/%06d], loss_rec=%.3f, loss_clf=%.2f, alpha=%.2e, acc=%.2f'
            print(log % slot_value, end='\r')
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            if (iteration + 1) % 1000 == 0:
                self.save_model(model_path, 's1', iteration + 1)
        print()
    elif mode == 'patchGAN':
        for iteration in range(hps.patch_iters):
            #==================train D==================#
            for step in range(hps.n_patch_steps):
                data_s = next(self.source_loader)
                data_t = next(self.target_loader)
                _, x_s = self.permute_data(data_s)
                c, x_t = self.permute_data(data_t)
                # encode
                enc = self.encode_step(x_s)
                # sample c (random target-speaker condition)
                c_prime = self.sample_c(x_t.size(0))
                # generator
                x_tilde = self.gen_step(enc, c_prime)
                # discriminator: Wasserstein estimate + gradient penalty
                w_dis, real_logits, gp = self.patch_step(x_t, x_tilde, is_dis=True)
                # aux classification loss on the real samples
                loss_clf = self.cal_loss(real_logits, c, shift=True)
                loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                reset_grad([self.PatchDiscriminator])
                loss.backward()
                grad_clip([self.PatchDiscriminator], self.hps.max_grad_norm)
                self.patch_opt.step()
                # calculate acc
                acc = self.cal_acc(real_logits, c, shift=True)
                info = {
                    f'{flag}/w_dis': w_dis.item(),
                    f'{flag}/gp': gp.item(),
                    f'{flag}/real_loss_clf': loss_clf.item(),
                    f'{flag}/real_acc': acc,
                }
                slot_value = (step, iteration + 1, hps.patch_iters) + \
                    tuple([value for value in info.values()])
                log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f'
                print(log % slot_value, end='\r')
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
            #==================train G==================#
            data_s = next(self.source_loader)
            data_t = next(self.target_loader)
            _, x_s = self.permute_data(data_s)
            c, x_t = self.permute_data(data_t)
            # encode
            enc = self.encode_step(x_s)
            # sample c
            c_prime = self.sample_c(x_t.size(0))
            # generator
            x_tilde = self.gen_step(enc, c_prime)
            # discriminator (generator side: adversarial loss on the fake)
            loss_adv, fake_logits = self.patch_step(x_t, x_tilde, is_dis=False)
            # aux classification loss: fake should classify as sampled c_prime
            loss_clf = self.cal_loss(fake_logits, c_prime, shift=True)
            loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
            reset_grad([self.Generator])
            loss.backward()
            grad_clip([self.Generator], self.hps.max_grad_norm)
            self.gen_opt.step()
            # calculate acc
            acc = self.cal_acc(fake_logits, c_prime, shift=True)
            info = {
                f'{flag}/loss_adv': loss_adv.item(),
                f'{flag}/fake_loss_clf': loss_clf.item(),
                f'{flag}/fake_acc': acc,
            }
            slot_value = (iteration + 1, hps.patch_iters) + \
                tuple([value for value in info.values()])
            log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f'
            print(log % slot_value, end='\r')
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            if (iteration + 1) % 1000 == 0:
                # offset by hps.iters so stage-2 checkpoints continue stage-1 numbering
                self.save_model(model_path, 's2', iteration + 1 + hps.iters)
        print()
    else:
        raise NotImplementedError()
def train(self, model_path, flag='train'):
    """Two-stage adversarial training of the voice-conversion autoencoder.

    Per iteration (gated by hps thresholds):
      1. latent-discriminator updates (WGAN-GP on encoder outputs),
      2. optional patch-discriminator updates with an auxiliary
         speaker-classification loss,
      3. Encoder/Decoder update (reconstruction + ramped adversarial term),
      4. optional Decoder-only update against the patch discriminator.

    Args:
        model_path: checkpoint path prefix handed to self.save_model.
        flag: tag prefix for TensorBoard scalar names.
    """
    # load hyperparams
    hps = self.hps
    for iteration in range(hps.iters):
        # Linearly ramp the adversarial weight from 0 (during encoder
        # pretraining) up to hps.alpha_enc at lat_sched_iters.
        if iteration + 1 < hps.lat_sched_iters and iteration >= hps.enc_pretrain_iters:
            current_alpha = hps.alpha_enc * (
                iteration + 1 - hps.enc_pretrain_iters) / (
                    hps.lat_sched_iters - hps.enc_pretrain_iters)
        else:
            current_alpha = 0
        if iteration >= hps.enc_pretrain_iters:
            # Right after pretraining ends, burn in the latent discriminator
            # with extra steps before dropping to the steady-state count.
            n_latent_steps = hps.n_latent_steps \
                if iteration > hps.enc_pretrain_iters else hps.dis_pretrain_iters
            for step in range(n_latent_steps):
                #===================== Train latent discriminator =====================#
                data = next(self.data_loader)
                (c_i, c_j), (x_i_t, x_i_tk, x_i_prime, x_j) = self.permute_data(data)
                # encode
                enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
                    x_i_t, x_i_tk, x_i_prime, x_j)
                # latent discriminate: Wasserstein estimate + gradient penalty
                latent_w_dis, latent_gp = self.latent_discriminate_step(
                    enc_i_t, enc_i_tk, enc_i_prime, enc_j)
                lat_loss = -hps.alpha_dis * latent_w_dis + hps.lambda_ * latent_gp
                reset_grad([self.LatentDiscriminator])
                lat_loss.backward()
                grad_clip([self.LatentDiscriminator], self.hps.max_grad_norm)
                self.lat_opt.step()
                # print info (.item() replaces the removed Variable.data[0] API)
                info = {
                    f'{flag}/D_latent_w_dis': latent_w_dis.item(),
                    f'{flag}/latent_gp': latent_gp.item(),
                }
                slot_value = (step, iteration + 1, hps.iters) + tuple(info.values())
                log = 'lat_D-%d:[%06d/%06d], w_dis=%.3f, gp=%.2f'
                print(log % slot_value)
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration)
        # two stage training
        if iteration >= hps.patch_start_iter:
            for step in range(hps.n_patch_steps):
                #===================== Train patch discriminator =====================#
                data = next(self.data_loader)
                (c_i, _), (x_i_t, _, _, _) = self.permute_data(data)
                # encode (single-element tuple unpack)
                enc_i_t, = self.encode_step(x_i_t)
                c_sample = self.sample_c(x_i_t.size(0))
                x_tilde = self.decode_step(enc_i_t, c_i)
                patch_w_dis, real_logits, fake_logits, patch_gp = \
                    self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=True)
                # FIX: c_loss / real_acc / fake_acc were referenced below but
                # never defined (NameError whenever this branch ran). Compute
                # the auxiliary speaker-classification loss and accuracies from
                # the discriminator logits against the true speaker c_i —
                # TODO(review): confirm targets against the reference repo.
                c_loss = self.cal_loss(real_logits, c_i) + \
                    self.cal_loss(fake_logits, c_i)
                real_acc = self.cal_acc(real_logits, c_i)
                fake_acc = self.cal_acc(fake_logits, c_i)
                patch_loss = -hps.beta_dis * patch_w_dis + hps.lambda_ * patch_gp \
                    + hps.beta_clf * c_loss
                reset_grad([self.PatchDiscriminator])
                patch_loss.backward()
                grad_clip([self.PatchDiscriminator], self.hps.max_grad_norm)
                self.patch_opt.step()
                # print info
                info = {
                    f'{flag}/D_patch_w_dis': patch_w_dis.item(),
                    f'{flag}/patch_gp': patch_gp.item(),
                    f'{flag}/c_loss': c_loss.item(),
                    f'{flag}/real_acc': real_acc,
                    f'{flag}/fake_acc': fake_acc,
                }
                slot_value = (step, iteration + 1, hps.iters) + tuple(info.values())
                log = 'patch_D-%d:[%06d/%06d], w_dis=%.3f, gp=%.2f, c_loss=%.3f, real_acc=%.2f, fake_acc=%.2f'
                print(log % slot_value)
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration)
        #===================== Train G =====================#
        data = next(self.data_loader)
        (c_i, c_j), (x_i_t, x_i_tk, x_i_prime, x_j) = self.permute_data(data)
        # encode
        enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
            x_i_t, x_i_tk, x_i_prime, x_j)
        # decode
        x_tilde = self.decode_step(enc_i_t, c_i)
        loss_rec = torch.mean(torch.abs(x_tilde - x_i_t))
        # latent discriminate (generator side: fool the discriminator)
        loss_adv = self.latent_discriminate_step(
            enc_i_t, enc_i_tk, enc_i_prime, enc_j, is_dis=False)
        ae_loss = loss_rec + current_alpha * loss_adv
        reset_grad([self.Encoder, self.Decoder])
        # Keep the graph alive when the decoder-only update below reuses it.
        retain_graph = hps.n_patch_steps > 0
        ae_loss.backward(retain_graph=retain_graph)
        grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
        self.ae_opt.step()
        info = {
            f'{flag}/loss_rec': loss_rec.item(),
            f'{flag}/loss_adv': loss_adv.item(),
            f'{flag}/alpha': current_alpha,
        }
        slot_value = (iteration + 1, hps.iters) + tuple(info.values())
        log = 'G:[%06d/%06d], loss_rec=%.2f, loss_adv=%.2f, alpha=%.2e'
        print(log % slot_value)
        for tag, value in info.items():
            self.logger.scalar_summary(tag, value, iteration + 1)
        # patch discriminate: decoder update against the patch discriminator
        if hps.n_patch_steps > 0 and iteration >= hps.patch_start_iter:
            c_sample = self.sample_c(x_i_t.size(0))
            x_tilde = self.decode_step(enc_i_t, c_sample)
            patch_w_dis, real_logits, fake_logits = \
                self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=False)
            # FIX: same undefined-name bug as above. The decoder wants its
            # converted sample classified as the sampled speaker c_sample —
            # TODO(review): confirm targets against the reference repo.
            c_loss = self.cal_loss(fake_logits, c_sample)
            real_acc = self.cal_acc(real_logits, c_i)
            fake_acc = self.cal_acc(fake_logits, c_sample)
            patch_loss = hps.beta_dec * patch_w_dis + hps.beta_clf * c_loss
            reset_grad([self.Decoder])
            patch_loss.backward()
            grad_clip([self.Decoder], self.hps.max_grad_norm)
            self.decoder_opt.step()
            info = {
                f'{flag}/G_patch_w_dis': patch_w_dis.item(),
                f'{flag}/c_loss': c_loss.item(),
                f'{flag}/real_acc': real_acc,
                f'{flag}/fake_acc': fake_acc,
            }
            slot_value = (iteration + 1, hps.iters) + tuple(info.values())
            log = 'G:[%06d/%06d]: patch_w_dis=%.2f, c_loss=%.2f, real_acc=%.2f, fake_acc=%.2f'
            print(log % slot_value)
            for tag, value in info.items():
                self.logger.scalar_summary(tag, value, iteration + 1)
        if iteration % 1000 == 0 or iteration + 1 == hps.iters:
            self.save_model(model_path, iteration)
def train():
    """Run one training epoch for the message model and optional user model.

    Relies on module-level globals: args, message_model, user_model,
    train_message, train_user, LB_train, UB_train, message_optimizer,
    user_optimizer, get_batch, repackage_hidden and utils. The interface
    (no parameters, no return value) is unchanged.
    """
    # tell PyTorch to use training mode (dropout, batch norm, etc)
    message_model.train()
    if args.user_classifier != 'none':
        user_model.train()
    if args.message_classifier == 'rnn':
        hidden = message_model.init_hidden(args.batch_size)
    # FIX: renamed from `iter`, which shadowed the builtin.
    n_batches = int(args.num_training / args.batch_size)
    for indx in range(n_batches):
        y_mssg = []
        y_usr = []
        x_message = Variable(get_batch(train_message, indx))
        if args.user_classifier != 'none':
            x_user = Variable(get_batch(train_user, indx))
        if args.cuda:
            x_message = x_message.cuda()
            if args.user_classifier != 'none':
                x_user = x_user.cuda()
        # Skip the ragged final batch.
        if (x_message.size())[0] == args.batch_size:
            ## Update message learner parameters
            if args.message_classifier == 'rnn':
                hidden1 = repackage_hidden(hidden)
                y_msg_rnn, hidden = message_model(x_message.t(), hidden1)
                y_mssg.append(y_msg_rnn)
            elif args.message_classifier == 'emb':
                y_msg_emd = message_model(x_message.t())
                y_mssg.append(y_msg_emd)
            else:
                y_msg = message_model(x_message)
                y_mssg.append(y_msg)
            if args.user_classifier == 'emb':
                y_usr = [user_model(x_user.t())]
            elif args.user_classifier == 'node2vec':
                y_usr = [user_model(x_user)]
            lb = torch.FloatTensor(get_batch(LB_train, indx))
            ub = torch.FloatTensor(get_batch(UB_train, indx))
            loss, _, _ = utils.cross_entropy_criterion(y_mssg, y_usr, lb, ub,
                                                       args.cuda)
            # FIX: `retain_variables` was renamed `retain_graph` in
            # PyTorch 0.4 and later removed; the graph must survive for the
            # user-model backward pass below.
            loss.backward(retain_graph=True)
            message_optimizer.step()
            # Zero gradients after stepping so the next backward starts clean.
            utils.reset_grad(message_model.parameters())
            if args.user_classifier != 'none':
                utils.reset_grad(user_model.parameters())
            ## Update user learner parameters if there is user learner
            if args.user_classifier != 'none':
                if args.message_classifier in ('doc2vec', 'bow'):
                    y_mssg = [message_model(x_message)]
                elif args.message_classifier == 'rnn':
                    hidden1 = repackage_hidden(hidden)
                    y_msg, hidden = message_model(x_message.t(), hidden1)
                    y_mssg = [y_msg]
                elif args.message_classifier == 'emb':
                    y_mssg = [message_model(x_message.t())]
                if args.user_classifier == 'emb':
                    y_usr = [user_model(x_user.t())]
                else:
                    y_usr = [user_model(x_user)]
                loss, _, _ = utils.cross_entropy_criterion(y_mssg, y_usr, lb,
                                                           ub, args.cuda)
                loss.backward()
                user_optimizer.step()
                utils.reset_grad(message_model.parameters())
                utils.reset_grad(user_model.parameters())
def train(all_models, training_models, solver, training_params, log_every, **kwargs):
    """Jointly train a (conditional) VAE, a contrastive model and an actor.

    Gradients from up to three losses (VAE reconstruction/KL, contrastive
    c-loss, actor loss) are accumulated via separate backward() calls and
    applied with a single solver.step() per iteration.

    Args:
        all_models: (model, c_model, actor) triple; only members that also
            appear in `training_models` receive gradient updates.
        training_models: subset of all_models to actually train.
        solver: optimizer over `training_params`; pass None to skip updates.
        training_params: parameters whose grads are zeroed after each step.
        log_every: logging / checkpoint / image-dump period, in iterations.
        **kwargs: hyperparameters — k, n_epochs, batch_size, N, c_type,
            vae_w, vae_b, savepath, conditional, data_dir, test_dir,
            use_o_neg, pretrain.
    """
    model, c_model, actor = all_models
    k_steps = kwargs["k"]
    num_epochs = kwargs["n_epochs"]
    batch_size = kwargs["batch_size"]
    N = kwargs["N"]
    c_type = kwargs["c_type"]
    vae_weight = kwargs["vae_w"]
    beta = kwargs["vae_b"]
    # Configure experiment path
    savepath = kwargs['savepath']
    conditional = kwargs["conditional"]
    configure('%s/var_log' % savepath, flush_secs=5)
    ### Load data ### -- assuming appropriate npy format
    # NOTE(review): looks like an object array of variable-length trajectories
    # (indexed as data[i] and with fancy indexing data[idx, t] below); such
    # .npy files need allow_pickle=True on NumPy >= 1.16.3 — confirm.
    data_file = kwargs["data_dir"]
    data = np.load(data_file)
    n_trajs = len(data)
    # Transitions per trajectory, excluding the last k_steps frames so that
    # the k-step lookahead below never indexes past the end.
    data_size = sum([len(data[i]) - k_steps for i in range(n_trajs)])
    print('Number of trajectories: %d' % n_trajs)  # 315
    print('Number of transitions: %d' % data_size)  # 378315
    test_file = kwargs["test_dir"]
    test_data = np.load(test_file)
    test_context = get_torch_images_from_numpy(test_data, conditional,
                                               one_image=True)
    ### Train models ###
    # Zero placeholders so log_info has values even before a loss is computed.
    # NOTE(review): .cuda() makes this function require a GPU.
    c_loss = vae_loss = a_loss = torch.Tensor([0]).cuda()
    for epoch in range(num_epochs):
        # n_batch is loop-invariant; recomputed each epoch for clarity.
        n_batch = int(data_size / batch_size)
        print('********** Epoch %i ************' % epoch)
        for it in range(n_batch):
            # Sample a batch of (trajectory, timestep) pairs.
            idx, t = get_idx_t(batch_size, k_steps, n_trajs, data)
            o, c = get_torch_images_from_numpy(data[idx, t], conditional)
            # Random lookahead offset in [0, k_steps) per sample.
            ks = np.random.choice(k_steps, batch_size)
            o_next, _ = get_torch_images_from_numpy(data[idx, t + ks],
                                                    conditional)
            o_neg = get_negative_examples(
                data, idx, batch_size, N,
                conditional) if kwargs["use_o_neg"] else None
            o_pred, mu, logvar, cond_info = model(o, c)
            o_next_pred, _, _, _ = model(o_next, c)
            # VAE loss
            if model in training_models:
                vae_loss = loss_function(o_pred, o, mu, logvar,
                                         cond_info.get("means_cond", None),
                                         cond_info.get("log_var_cond", None),
                                         beta=beta) * vae_weight
                vae_loss.backward()
            # C loss (contrastive); only after the VAE pretraining epochs
            if c_model in training_models and epoch >= kwargs["pretrain"]:
                c_loss = get_c_loss(model, c_model, c_type, o_pred,
                                    o_next_pred, c, N, o_neg)
                c_loss.backward()
            # Actor loss; also gated on the pretraining threshold
            if actor in training_models and epoch >= kwargs["pretrain"]:
                a = get_torch_actions(data[idx, t + 1])
                a_loss = actor.loss(a, o, o_next, c)
                a_loss.backward()
            ### Update models ###
            if solver is not None:
                solver.step()
                reset_grad(training_params)
            if it % log_every == 0:
                ### Log info ###
                log_info(c_loss, vae_loss, a_loss, model, conditional,
                         cond_info, it, n_batch, epoch)
                ### Save params ### (rotating pool of 5 checkpoint slots)
                if not os.path.exists('%s/var' % savepath):
                    os.makedirs('%s/var' % savepath)
                torch.save(model.state_dict(),
                           '%s/var/vae-%d-last-5' % (savepath, epoch % 5 + 1))
                torch.save(c_model.state_dict(),
                           '%s/var/cpc-%d-last-5' % (savepath, epoch % 5 + 1))
                torch.save(
                    actor.state_dict(),
                    '%s/var/actor-%d-last-5' % (savepath, epoch % 5 + 1))
                ### Log images ###
                with torch.no_grad():
                    n_contexts = 7
                    n_samples_per_c = 8
                    o_distinct_c = get_negative_examples(
                        data, idx[:n_contexts], n_contexts, n_samples_per_c,
                        conditional)
                    log_images(
                        o[:n_contexts], o_pred[:n_contexts],
                        o_distinct_c.reshape(n_samples_per_c, n_contexts,
                                             *o_distinct_c.size()[1:]),
                        c[:n_contexts], test_context, model, c_model,
                        n_contexts, n_samples_per_c, savepath, epoch)