def finetune_discrete_decoder(trainer, look_up, model_path, flag='train'):
    """Keep training the autoencoder against ``look_up``-transformed targets.

    ``trainer`` is expected to be trained already.  For
    ``trainer.hps.enc_pretrain_iters`` iterations a batch is drawn from
    ``trainer.data_loader``, encoded, decoded with its condition and compared
    (mean absolute error) with ``look_up(x)``; the encoder and decoder are
    then updated through ``trainer.ae_opt``.

    Args:
        trainer: object providing ``hps``, ``data_loader``, ``permute_data``,
            ``encode_step``, ``decode_step``, ``Encoder``, ``Decoder``,
            ``ae_opt``, ``logger`` and ``save_model``.
        look_up: callable applied to the input batch to build the
            reconstruction target.
        model_path: forwarded to ``trainer.save_model`` (tag ``'dc'``) every
            1000 iterations.
        flag: prefix of the scalar tag sent to the logger.
    """
    hps = trainer.hps
    progress_fmt = 'train_discrete:[%06d/%06d], loss_rec=%.3f'

    for it in range(hps.enc_pretrain_iters):
        step_no = it + 1
        batch = next(trainer.data_loader)
        c, x = trainer.permute_data(batch)

        # The raw batch is encoded first; only the target goes through look_up.
        latent = trainer.encode_step(x)
        target = look_up(x)
        reconstruction = trainer.decode_step(latent, c)
        loss_rec = torch.mean(torch.abs(reconstruction - target))

        reset_grad([trainer.Encoder, trainer.Decoder])
        loss_rec.backward()
        grad_clip([trainer.Encoder, trainer.Decoder], trainer.hps.max_grad_norm)
        trainer.ae_opt.step()

        # tb info
        scalars = {f'{flag}/disc_loss_rec': loss_rec.item()}
        fields = (step_no, hps.enc_pretrain_iters) + tuple(scalars.values())
        # '\r' keeps the progress report on a single console line.
        print(progress_fmt % fields, end='\r')

        if it % 100 == 0:
            for tag, value in scalars.items():
                trainer.logger.scalar_summary(tag, value, step_no)
        if step_no % 1000 == 0:
            trainer.save_model(model_path, 'dc', step_no)

    # Move past the carriage-return progress line.
    print()
def train(self, model_path, flag='train'):
    """Train ``self.SpeakerClassifier`` for ``self.hps.iters`` iterations.

    Each iteration draws a batch, encodes it, classifies the encoding and
    updates the classifier through ``self.opt``.  Loss and accuracy are
    printed and logged every iteration.  Every 1000 iterations, and on the
    last one, ``self.valid(n_batches=10)`` is run, its results are printed
    and logged, and a checkpoint is written with ``self.save_model``.

    Args:
        model_path: forwarded to ``self.save_model``.
        flag: prefix of the scalar tags sent to the logger.
    """
    # load hyperparams
    hps = self.hps
    for iteration in range(hps.iters):
        data = next(self.data_loader)
        y, x = self.permute_data(data)
        # encode
        enc = self.encode_step(x)
        # forward to classifier
        logits = self.forward_step(enc)
        # calculate loss
        loss = self.cal_loss(logits, y)
        # optimize
        reset_grad([self.SpeakerClassifier])
        loss.backward()
        grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
        self.opt.step()
        # calculate acc
        acc = cal_acc(logits, y)
        # print info
        info = {
            # .item() instead of .data[0]: indexing a 0-dim loss tensor
            # raises on PyTorch >= 0.4.
            f'{flag}/loss': loss.item(),
            f'{flag}/acc': acc,
        }
        slot_value = (iteration + 1, hps.iters) + tuple(info.values())
        log = 'iter:[%06d/%06d], loss=%.3f, acc=%.3f'
        print(log % slot_value, end='\r')
        for tag, value in info.items():
            self.logger.scalar_summary(tag, value, iteration)

        if iteration % 1000 == 0 or iteration + 1 == hps.iters:
            valid_loss, valid_acc = self.valid(n_batches=10)
            # print info
            info = {
                f'{flag}/valid_loss': valid_loss,
                f'{flag}/valid_acc': valid_acc,
            }
            slot_value = (iteration + 1, hps.iters) + tuple(info.values())
            log = 'iter:[%06d/%06d], valid_loss=%.3f, valid_acc=%.3f'
            print(log % slot_value)
            for tag, value in info.items():
                self.logger.scalar_summary(tag, value, iteration)
            self.save_model(model_path, iteration)
# train(): mode-dispatched training loop; the visible branches are 'pretrain_G'
# (reconstruction-only encoder/decoder updates), 'pretrain_D' (speaker-classifier
# updates) and 'patchGAN' (patch-discriminator steps followed by a generator step).
# NOTE(review): the text of this definition is incomplete here. The generator-side
# logging loop of the 'patchGAN' branch (`for tag, value in info.items():`) has no
# body before the following `else:`, and the text stops right after the "train G"
# batch is fetched. The definition is also collapsed onto three physical lines, so
# the first inline `#` comment hides everything after it from the parser.
# Restore the block from version control before editing; the code is left untouched.
def train(self, model_path, flag='train', mode='train'): # load hyperparams hps = self.hps if mode == 'pretrain_G': for iteration in range(hps.enc_pretrain_iters): data = next(self.data_loader) c, x = self.permute_data(data) # encode enc = self.encode_step(x) x_tilde = self.decode_step(enc, c) loss_rec = torch.mean(torch.abs(x_tilde - x)) reset_grad([self.Encoder, self.Decoder]) loss_rec.backward() grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm) self.ae_opt.step() # tb info info = { f'{flag}/pre_loss_rec': loss_rec.item(), } slot_value = (iteration + 1, hps.enc_pretrain_iters) + tuple([value for value in info.values()]) log = 'pre_G:[%06d/%06d], loss_rec=%.3f' print(log % slot_value) if iteration % 100 == 0: for tag, value in info.items(): self.logger.scalar_summary(tag, value, iteration + 1) elif mode == 'pretrain_D': for iteration in range(hps.dis_pretrain_iters): data = next(self.data_loader) c, x = self.permute_data(data) # encode enc = self.encode_step(x) # classify speaker logits = self.clf_step(enc) loss_clf = self.cal_loss(logits, c) # update reset_grad([self.SpeakerClassifier]) loss_clf.backward() grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm) self.clf_opt.step() # calculate acc acc = cal_acc(logits, c) info = { f'{flag}/pre_loss_clf': loss_clf.item(), f'{flag}/pre_acc': acc, } slot_value = (iteration + 1, hps.dis_pretrain_iters) + tuple([value for value in info.values()]) log = 'pre_D:[%06d/%06d], loss_clf=%.2f, acc=%.2f' print(log % slot_value) if iteration % 100 == 0: for tag, value in info.items(): self.logger.scalar_summary(tag, value, iteration + 1) elif mode == 'patchGAN': for iteration in range(hps.patch_iters): #=======train D=========# for step in range(hps.n_patch_steps): data = next(self.data_loader) c, x = self.permute_data(data) ## encode enc = self.encode_step(x) # sample c c_prime = self.sample_c(x.size(0)) # generator x_tilde = self.gen_step(enc, c_prime) # discriminstor w_dis, real_logits, gp = 
self.patch_step(x, x_tilde, is_dis=True) # aux classification loss loss_clf = self.cal_loss(real_logits, c) loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp reset_grad([self.PatchDiscriminator]) loss.backward() grad_clip([self.PatchDiscriminator], self.hps.max_grad_norm) self.patch_opt.step() # calculate acc acc = cal_acc(real_logits, c) info = { f'{flag}/w_dis': w_dis.item(), f'{flag}/gp': gp.item(), f'{flag}/real_loss_clf': loss_clf.item(), f'{flag}/real_acc': acc, } slot_value = (step, iteration+1, hps.patch_iters) + tuple([value for value in info.values()]) log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f' print(log % slot_value) if iteration % 100 == 0: for tag, value in info.items(): self.logger.scalar_summary(tag, value, iteration + 1) #=======train G=========# data = next(self.data_loader) c, x = self.permute_data(data) # encode enc = self.encode_step(x) # sample c c_prime = self.sample_c(x.size(0)) # generator x_tilde = self.gen_step(enc, c_prime) # discriminstor loss_adv, fake_logits = self.patch_step(x, x_tilde, is_dis=False) # aux classification loss loss_clf = self.cal_loss(fake_logits, c_prime) loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv reset_grad([self.Generator]) loss.backward() grad_clip([self.Generator], self.hps.max_grad_norm) self.gen_opt.step() # calculate acc acc = cal_acc(fake_logits, c_prime) info = { f'{flag}/loss_adv': loss_adv.item(), f'{flag}/fake_loss_clf': loss_clf.item(), f'{flag}/fake_acc': acc, } slot_value = (iteration+1, hps.patch_iters) + tuple([value for value in info.values()]) log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f' print(log % slot_value) if iteration % 100 == 0: for tag, value in info.items():
else: current_alpha = hps.alpha_enc #==================train D==================# for step in range(hps.n_latent_steps): data = next(self.data_loader) c, x = self.permute_data(data) # encode enc = self.encode_step(x) # classify speaker logits = self.clf_step(enc) loss_clf = self.cal_loss(logits, c) loss = hps.alpha_dis * loss_clf # update reset_grad([self.SpeakerClassifier]) loss.backward() grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm) self.clf_opt.step() # calculate acc acc = cal_acc(logits, c) info = { f'{flag}/D_loss_clf': loss_clf.item(), f'{flag}/D_acc': acc, } slot_value = (step, iteration + 1, hps.iters) + tuple([value for value in info.values()]) log = 'D-%d:[%06d/%06d], loss_clf=%.2f, acc=%.2f' print(log % slot_value) if iteration % 100 == 0: for tag, value in info.items(): self.logger.scalar_summary(tag, value, iteration + 1) #==================train G==================# data = next(self.data_loader)
def train(self, model_path, flag='train', mode='train'):
    """Run the training phase selected by ``mode`` with fixed iteration counts.

    ``model_path`` is treated as a directory: it is created (and chmod-ed to
    ``0o755``) when missing, and checkpoints are written through
    ``self.save_model`` using ``<model_path>/model.pkl`` as the path.

    Modes:
        ``'pretrain_G'``: 2200 reconstruction-only updates of the encoder and
            decoder (``self.ae_opt``).  No checkpoint is written.
        ``'pretrain_D'``: 2200 updates of the speaker classifier
            (``self.clf_opt``).  No checkpoint is written.
        ``'patchGAN'``: 1100 rounds of ``hps.n_patch_steps`` patch
            discriminator updates followed by one generator update.
        ``'train'``: 1100 rounds of ``hps.n_latent_steps`` classifier updates
            followed by one autoencoder update whose loss is
            ``loss_rec - current_alpha * loss_clf``.

    Any other ``mode`` only prepares ``model_path``.

    Args:
        model_path: directory that receives the checkpoints.
        flag: prefix of the scalar tags sent to ``self.logger``.
        mode: one of the phase names listed above.
    """
    if not os.path.isdir(model_path):
        os.makedirs(model_path)
        os.chmod(model_path, 0o755)
    model_path = os.path.join(model_path, 'model.pkl')
    # load hyperparams
    hps = self.hps
    # Phase lengths are fixed here rather than read from hps.
    pretrain_iters = 2200
    adversarial_iters = 1100

    if mode == 'pretrain_G':
        for iteration in range(pretrain_iters):
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            x_tilde = self.decode_step(enc, c)
            loss_rec = torch.mean(torch.abs(x_tilde - x))
            reset_grad([self.Encoder, self.Decoder])
            loss_rec.backward()
            grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
            self.ae_opt.step()
            # tb info
            info = {
                f'{flag}/pre_loss_rec': loss_rec.item(),
            }
            slot_value = (iteration + 1, pretrain_iters) + tuple(info.values())
            log = 'pre_G:[%06d/%06d], loss_rec=%.3f'
            print(log % slot_value)
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)

    elif mode == 'pretrain_D':
        for iteration in range(pretrain_iters):
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            # classify speaker
            logits = self.clf_step(enc)
            loss_clf = self.cal_loss(logits, c)
            # update
            reset_grad([self.SpeakerClassifier])
            loss_clf.backward()
            grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
            self.clf_opt.step()
            # calculate acc
            acc = cal_acc(logits, c)
            info = {
                f'{flag}/pre_loss_clf': loss_clf.item(),
                f'{flag}/pre_acc': acc,
            }
            slot_value = (iteration + 1, pretrain_iters) + tuple(info.values())
            log = 'pre_D:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
            print(log % slot_value)
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)

    elif mode == 'patchGAN':
        for iteration in range(adversarial_iters):
            #=======train D=========#
            for step in range(hps.n_patch_steps):
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                # encode
                enc = self.encode_step(x)
                # sample c
                c_prime = self.sample_c(x.size(0))
                # generator
                x_tilde = self.gen_step(enc, c_prime)
                # discriminator
                w_dis, real_logits, gp = self.patch_step(x, x_tilde, is_dis=True)
                # aux classification loss
                loss_clf = self.cal_loss(real_logits, c)
                loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                reset_grad([self.PatchDiscriminator])
                loss.backward()
                grad_clip([self.PatchDiscriminator], self.hps.max_grad_norm)
                self.patch_opt.step()
                # calculate acc
                acc = cal_acc(real_logits, c)
                info = {
                    f'{flag}/w_dis': w_dis.item(),
                    f'{flag}/gp': gp.item(),
                    f'{flag}/real_loss_clf': loss_clf.item(),
                    f'{flag}/real_acc': acc,
                }
                slot_value = (step, iteration + 1, adversarial_iters) + tuple(info.values())
                log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f'
                print(log % slot_value)
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)

            #=======train G=========#
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            # sample c
            c_prime = self.sample_c(x.size(0))
            # generator
            x_tilde = self.gen_step(enc, c_prime)
            # discriminator
            loss_adv, fake_logits = self.patch_step(x, x_tilde, is_dis=False)
            # aux classification loss
            loss_clf = self.cal_loss(fake_logits, c_prime)
            loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
            reset_grad([self.Generator])
            loss.backward()
            grad_clip([self.Generator], self.hps.max_grad_norm)
            self.gen_opt.step()
            # calculate acc
            acc = cal_acc(fake_logits, c_prime)
            info = {
                f'{flag}/loss_adv': loss_adv.item(),
                f'{flag}/fake_loss_clf': loss_clf.item(),
                f'{flag}/fake_acc': acc,
            }
            slot_value = (iteration + 1, adversarial_iters) + tuple(info.values())
            log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f'
            print(log % slot_value)
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            # Test against the real loop length (not hps.patch_iters) so the
            # final state is always checkpointed.
            if iteration % 1000 == 0 or iteration + 1 == adversarial_iters:
                self.save_model(model_path, iteration + hps.iters)

    elif mode == 'train':
        for iteration in range(adversarial_iters):
            # calculate current alpha: linear ramp over the first
            # hps.lat_sched_iters iterations, constant afterwards
            if iteration < hps.lat_sched_iters:
                current_alpha = hps.alpha_enc * (iteration / hps.lat_sched_iters)
            else:
                current_alpha = hps.alpha_enc

            #==================train D==================#
            for step in range(hps.n_latent_steps):
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                # encode
                enc = self.encode_step(x)
                # classify speaker
                logits = self.clf_step(enc)
                loss_clf = self.cal_loss(logits, c)
                loss = hps.alpha_dis * loss_clf
                # update
                reset_grad([self.SpeakerClassifier])
                loss.backward()
                grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
                self.clf_opt.step()
                # calculate acc
                acc = cal_acc(logits, c)
                info = {
                    f'{flag}/D_loss_clf': loss_clf.item(),
                    f'{flag}/D_acc': acc,
                }
                slot_value = (step, iteration + 1, adversarial_iters) + tuple(info.values())
                log = 'D-%d:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
                print(log % slot_value)
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)

            #==================train G==================#
            data = next(self.data_loader)
            c, x = self.permute_data(data)
            # encode
            enc = self.encode_step(x)
            # decode
            x_tilde = self.decode_step(enc, c)
            loss_rec = torch.mean(torch.abs(x_tilde - x))
            # classify speaker
            logits = self.clf_step(enc)
            acc = cal_acc(logits, c)
            loss_clf = self.cal_loss(logits, c)
            # maximize classification loss
            loss = loss_rec - current_alpha * loss_clf
            reset_grad([self.Encoder, self.Decoder])
            loss.backward()
            grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
            self.ae_opt.step()
            info = {
                f'{flag}/loss_rec': loss_rec.item(),
                f'{flag}/G_loss_clf': loss_clf.item(),
                f'{flag}/alpha': current_alpha,
                f'{flag}/G_acc': acc,
            }
            slot_value = (iteration + 1, adversarial_iters) + tuple(info.values())
            log = 'G:[%06d/%06d], loss_rec=%.3f, loss_clf=%.2f, alpha=%.2e, acc=%.2f'
            print(log % slot_value)
            if iteration % 100 == 0:
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            # Test against the real loop length (not hps.iters) so the final
            # state is always checkpointed.
            if iteration % 1000 == 0 or iteration + 1 == adversarial_iters:
                self.save_model(model_path, iteration)
def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
             nsteps, ent_coef, vf_coef, max_grad_norm, cell=256, sv_M=32,
             algo='regular', ib_alpha=1e-3):
    """Build the acting/training graphs and bind the training closures.

    Two policy instances are created in one session (``'step_model'`` for
    acting, ``'train_model'`` for training).  For ``algo`` equal to
    ``'use_svib_uniform'`` or ``'use_svib_gaussian'`` the clipped-surrogate
    losses are built per sample over ``sv_M`` copies, together with the
    ``sv_grads`` tensor and a separate representation loss/optimizer for the
    ``model/ordinary_encoder`` variables.  For any other ``algo`` a single
    clipped-surrogate loss is optimized over both variable groups.

    Attributes set on ``self``: ``generate_old_expand_data``, ``train``,
    ``train_model``, ``act_model``, ``step``, ``value``, ``initial_state``,
    ``save``, ``load`` and ``loss_names``.

    Args:
        policy: callable building a policy network; called as shown below.
        ob_space, ac_space: forwarded to ``policy``.
        nbatch_act: batch size of the acting model.
        nbatch_train: batch size of the training model.
        nsteps: forwarded to ``policy`` for the training model.
        ent_coef: weight of the entropy term.
        vf_coef: weight of the value-function loss.
        max_grad_norm: forwarded to ``grad_clip``.
        cell: forwarded to ``policy``; also the last dimension of ``sv_grads``.
        sv_M: number of per-sample copies in the SVIB variants.
        algo: ``'use_svib_uniform'``, ``'use_svib_gaussian'`` or anything else
            for the plain loss.
        ib_alpha: scales the gradient target (``1/sqrt``) and the explore
            term (``sqrt``).
    """
    # NOTE(review): `env_name` is read below but is not a parameter; it has
    # to be resolvable from an enclosing or module scope.
    use_svib = algo in ('use_svib_uniform', 'use_svib_gaussian')

    sess = tf_util.make_session()

    act_model = policy(sess, ob_space, ac_space, nbatch_act, 1, 1, cell=cell,
                       M=sv_M, model='step_model', algo=algo)
    train_model = policy(sess, ob_space, ac_space, nbatch_train, 1, nsteps,
                         cell=cell, M=sv_M, model='train_model', algo=algo)

    A = train_model.wpdtype.sample_placeholder([None])
    ADV = tf.placeholder(tf.float32, [None])
    R = tf.placeholder(tf.float32, [None])
    OLDNEGLOGPAC = tf.placeholder(tf.float32, [None])
    OLDNEGLOGPAC_expand = tf.placeholder(tf.float32, [None, sv_M])
    OLDVPRED = tf.placeholder(tf.float32, [None])
    OLDVPRED_expand = tf.placeholder(tf.float32, [None, sv_M])
    LR = tf.placeholder(tf.float32, [])
    CLIPRANGE = tf.placeholder(tf.float32, [])

    if use_svib:
        def expand_placeholder(X, M=sv_M):
            # Repeat a [nbatch] tensor along a new last axis -> [nbatch, M].
            return tf.tile(tf.expand_dims(X, axis=-1), [1, M])

        A_expand, R_expand = expand_placeholder(A), expand_placeholder(R)
        neglogpac_expand = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=train_model.wpi_expand, labels=A_expand)  # shape=[nbatch, sv_M]
        entropy_expand = tf.reduce_mean(
            cat_entropy(train_model.wpi_expand), axis=-1)  # shape=[nbatch]
        vpred_expand = train_model.wvf_expand[:, :, 0]
        vpredclipped_expand = OLDVPRED_expand + tf.clip_by_value(
            train_model.wvf_expand[:, :, 0] - OLDVPRED_expand,
            -CLIPRANGE, CLIPRANGE)
        vf_loss1_expand = tf.square(vpred_expand - R_expand)
        vf_loss2_expand = tf.square(vpredclipped_expand - R_expand)
        vf_loss_expand = .5 * tf.reduce_mean(
            tf.maximum(vf_loss1_expand, vf_loss2_expand), axis=-1)  # shape=[nbatch]
        ratio_expand = tf.exp(OLDNEGLOGPAC_expand - neglogpac_expand)
        ADV_expand = R_expand - OLDVPRED_expand
        # Moments over both axes of the [nbatch, sv_M] tensor: a single
        # scalar mean / variance.
        ADV_expand_mean, ADV_expand_var = tf.nn.moments(ADV_expand, axes=[0, 1])
        ADV_expand_normal = (ADV_expand - ADV_expand_mean) / (
            tf.sqrt(ADV_expand_var) + 1e-8)
        pg_losses_expand = -ADV_expand_normal * ratio_expand
        pg_losses2_expand = -ADV_expand_normal * tf.clip_by_value(
            ratio_expand, 1. - CLIPRANGE, 1. + CLIPRANGE)
        pg_loss_expand = tf.reduce_mean(
            tf.maximum(pg_losses_expand, pg_losses2_expand), axis=-1)
        J_theta = -(pg_loss_expand + vf_coef * vf_loss_expand
                    - ent_coef * entropy_expand)

        loss_expand = -J_theta / float(nbatch_train)
        pg_loss_expand_ = tf.reduce_mean(pg_loss_expand)
        vf_loss_expand_ = tf.reduce_mean(vf_loss_expand)
        entropy_expand_ = tf.reduce_mean(entropy_expand)

        log_p_grads = tf.gradients(
            J_theta / np.sqrt(ib_alpha),
            [train_model.wh_expand])[0]  # shape=[nbatch, sv_M, cell]
        if algo == 'use_svib_gaussian':
            mean, var = tf.nn.moments(
                train_model.wh_expand, axes=1,
                keep_dims=True)  # shape=[nbatch, 1, cell]
            gaussian_grad = -(train_model.wh_expand - mean) / (
                float(sv_M) * (var + 1e-3))
            # Add the Gaussian term rescaled to 5e-3 of the norm of
            # log_p_grads.
            log_p_grads += 5e-3 * (
                tf_l2norm(log_p_grads, axis=-1, keep_dims=True)
                / tf_l2norm(gaussian_grad, axis=-1, keep_dims=True)
            ) * gaussian_grad

        # Start from an empty [nbatch, 0, cell] tensor and append one slice
        # per copy in the loop below.
        sv_grads = tf.constant(0., tf.float32, shape=[nbatch_train, 0, cell])
        exploit_total_norm_square = 0
        explore_total_norm_square = 0
        explore_coef = 1.
        if env_name == 'SeaquestNoFrameskip-v4':
            explore_coef = 0.01
        elif env_name in [
                # Each name is its own list element.  A comma placed inside
                # the first literal previously fused the first two names into
                # one string, so neither environment ever matched.
                'AirRaidNoFrameskip-v4',
                'BreakoutNoFrameskip-v4',
                'AtlantisNoFrameskip-v4',
                'StarGunnerNoFrameskip-v4',
                'AsteroidsNoFrameskip-v4',
                'YarsRevengeNoFrameskip-v4',
        ]:
            explore_coef = 0.
        print('env_name:', env_name, 'explore_coef: ', explore_coef)

        for i in range(sv_M):
            exploit = tf.reduce_sum(
                train_model.rpf_matrix[:, :, i:i + 1] * log_p_grads, axis=1)
            explore = (np.sqrt(ib_alpha) * explore_coef
                       * train_model.rpf_grads[:, i, :])
            exploit_total_norm_square += tf.square(
                tf_l2norm(exploit, axis=-1, keep_dims=False))
            explore_total_norm_square += tf.square(
                tf_l2norm(explore, axis=-1, keep_dims=False))
            sv_grad = exploit + explore  # shape=[nbatch, cell]
            sv_grads = tf.concat(
                [sv_grads, tf.expand_dims(sv_grad, axis=1)], axis=1)

        # sv_grads is evaluated first and fed back through this placeholder,
        # so repr_loss treats it as a constant.
        SV_GRADS = tf.placeholder(tf.float32, [nbatch_train, sv_M, cell])
        repr_loss = tf.reduce_mean(SV_GRADS * train_model.wh_expand,
                                   axis=1)  # shape=[nbatch, cell]
        # max optimization problem to minimization problem
        repr_loss = -tf.reduce_mean(tf.reduce_sum(repr_loss, axis=-1))

        # op for debugging and visualization
        exploit_explore_ratio = tf.sqrt(
            exploit_total_norm_square
            / tf.maximum(explore_total_norm_square, 0.01))[0]

        grads_expand, grad_norm_expand = grad_clip(
            loss_expand, max_grad_norm, ['model/worker_module'])
        trainer_expand = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
        _train_expand = trainer_expand.apply_gradients(grads_expand)
        repr_grads, repr_global_norm = grad_clip(
            repr_loss, max_grad_norm, ['model/ordinary_encoder'])
        repr_trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
        _repr_train = repr_trainer.apply_gradients(repr_grads)
    else:
        print('env_name:', env_name)
        neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=train_model.wpi, labels=A)
        entropy = tf.reduce_mean(cat_entropy(train_model.wpi))
        vpred = train_model.wvf[:, 0]
        vpredclipped = OLDVPRED + tf.clip_by_value(
            train_model.wvf[:, 0] - OLDVPRED, -CLIPRANGE, CLIPRANGE)
        vf_losses1 = tf.square(vpred - R)
        vf_losses2 = tf.square(vpredclipped - R)
        vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))
        ratio = tf.exp(OLDNEGLOGPAC - neglogpac)
        pg_losses = -ADV * ratio
        pg_losses2 = -ADV * tf.clip_by_value(ratio, 1.0 - CLIPRANGE,
                                             1.0 + CLIPRANGE)
        pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2))
        loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef

        grads, grad_norm = grad_clip(
            loss, max_grad_norm,
            ['model/worker_module', 'model/ordinary_encoder'])
        trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
        _train = trainer.apply_gradients(grads)

    with tf.variable_scope('model'):
        params = tf.trainable_variables()

    def generate_old_expand_data(obs, noises, masks, actions, states=None):
        """Evaluate per-copy neglogpacs and value predictions for a batch.

        Returns ``[noises_expand, neglogpacs_expand, vpreds_expand]`` where
        the sampled noise is reshaped to ``[nbatch_train, sv_M - 1, ...]``.
        Uses tensors that exist only for the SVIB variants of ``algo``.
        """
        # Sample the noise once, then feed it back so the outputs below
        # correspond to exactly this sample.
        noises_expand = sess.run(train_model.noise_expand)
        repr_td_map = {
            train_model.wX: obs,
            train_model.istraining: False,
            A: actions,
            train_model.noise_expand: noises_expand,
            train_model.NOISE_KEEP: noises,
        }
        if states is not None:
            repr_td_map[train_model.wS] = states
            repr_td_map[train_model.wM] = masks
        neglogpacs_expand, vpreds_expand = sess.run(
            [neglogpac_expand, vpred_expand], feed_dict=repr_td_map)
        shape = noises_expand.shape
        noises_expand = noises_expand.reshape(nbatch_train, sv_M - 1,
                                              *shape[1:])
        return [noises_expand, neglogpacs_expand, vpreds_expand]

    def train(lr, cliprange, obs, noises, returns, masks, actions, values,
              neglogpacs, noises_expand=None, neglogpacs_expand=None,
              vpreds_expand=None, states=None):
        """Run one optimization step; returns the values in ``loss_names``."""
        advs = returns - values
        advs = (advs - advs.mean()) / (advs.std() + 1e-8)

        if use_svib:
            shape = noises_expand.shape
            # Undo the reshape done by generate_old_expand_data.
            noises_expand_ = noises_expand.reshape(
                nbatch_train * (sv_M - 1), *shape[2:])
            repr_td_map = {
                train_model.wX: obs,
                train_model.istraining: True,
                A: actions,
                R: returns,
                LR: lr,
                CLIPRANGE: cliprange,
                train_model.noise_expand: noises_expand_,
                train_model.NOISE_KEEP: noises,
                OLDNEGLOGPAC_expand: neglogpacs_expand,
                OLDVPRED_expand: vpreds_expand,
            }

        rl_td_map = {
            train_model.istraining: True,
            A: actions,
            R: returns,
            LR: lr,
            CLIPRANGE: cliprange,
        }

        if states is not None:
            if use_svib:
                repr_td_map[train_model.wS] = states
                repr_td_map[train_model.wM] = masks
            rl_td_map[train_model.wS] = states
            rl_td_map[train_model.wM] = masks

        if use_svib:
            sv_gradients, whs_expand, ir_ratio = sess.run(
                [sv_grads, train_model.wh_expand, exploit_explore_ratio],
                feed_dict=repr_td_map)
            rl_td_map[OLDNEGLOGPAC_expand] = neglogpacs_expand
            rl_td_map[OLDVPRED_expand] = vpreds_expand
            # Feed the evaluated representation directly into the RL update.
            rl_td_map[train_model.wh_expand] = whs_expand
            value_loss, policy_loss, policy_entropy, _, rl_grad_norm = sess.run(
                [vf_loss_expand_, pg_loss_expand_, entropy_expand_,
                 _train_expand, grad_norm_expand],
                feed_dict=rl_td_map)
            repr_td_map[SV_GRADS] = sv_gradients
            repr_grad_norm, represent_loss, __ = sess.run(
                [repr_global_norm, repr_loss, _repr_train],
                feed_dict=repr_td_map)
        else:
            rl_td_map[train_model.wX] = obs
            # noise won't be used when algo is 'regular'
            rl_td_map[train_model.noise] = noises
            rl_td_map[OLDNEGLOGPAC] = neglogpacs
            rl_td_map[OLDVPRED] = values
            rl_td_map[ADV] = advs
            value_loss, policy_loss, policy_entropy, _, rl_grad_norm = sess.run(
                [vf_loss, pg_loss, entropy, _train, grad_norm],
                feed_dict=rl_td_map)
            # No representation update in this branch; report zeros.
            represent_loss, ir_ratio, repr_grad_norm = 0., 0, 0.
        return (policy_loss, value_loss, policy_entropy, represent_loss,
                ir_ratio, rl_grad_norm, repr_grad_norm)

    self.loss_names = [
        'policy_loss', 'value_loss', 'policy_entropy', 'represent_loss',
        'exploit_explore_ratio', 'rl_grad_norm', 'repr_grad_norm'
    ]

    def save(save_path):
        """Dump the current values of ``params`` to ``save_path``."""
        ps = sess.run(params)
        make_path(osp.dirname(save_path))
        joblib.dump(ps, save_path)

    def load(load_path):
        """Assign values loaded from ``load_path`` to ``params``, in order."""
        loaded_params = joblib.load(load_path)
        restores = [p.assign(loaded_p)
                    for p, loaded_p in zip(params, loaded_params)]
        sess.run(restores)
        # If you want to load weights, also save/load observation scaling
        # inside VecNormalize

    self.generate_old_expand_data = generate_old_expand_data
    self.train = train
    self.train_model = train_model
    self.act_model = act_model
    self.step = act_model.step
    self.value = act_model.wvalue
    self.initial_state = act_model.w_initial_state
    self.save = save
    self.load = load
    tf.global_variables_initializer().run(session=sess)  # pylint: disable=E1101
# train(): mode-dispatched training loop; the visible branches are 'pretrain_G'
# (reconstruction-only encoder/decoder updates), 'pretrain_D' (speaker-classifier
# updates) and 'patchGAN' (patch-discriminator steps followed by a generator step).
# NOTE(review): after the 'patchGAN' generator logging the text continues with a
# "Train G" section that unpacks a different batch layout
# ((c_i, c_j), (x_i_t, x_i_tk, x_i_prime, x_j)) and reads `current_alpha` and
# `c_loss`, neither of which is assigned anywhere in the visible function. This
# looks like several revisions spliced together -- confirm against version control.
# NOTE(review): that section uses `loss_rec.data[0]` / `loss_adv.data[0]` while the
# rest of the function uses `.item()`; pick one once the intended code is restored.
# The definition is collapsed onto three physical lines, so the first inline `#`
# comment hides everything after it from the parser. The code is left untouched.
def train(self, model_path, flag='train', mode='train'): # load hyperparams hps = self.hps if mode == 'pretrain_G': for iteration in range(hps.enc_pretrain_iters): data = next(self.data_loader) c, x = self.permute_data(data) # encode enc = self.encode_step(x) x_tilde = self.decode_step(enc, c) loss_rec = torch.mean(torch.abs(x_tilde - x)) reset_grad([self.Encoder, self.Decoder]) loss_rec.backward() grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm) self.ae_opt.step() # tb info info = { f'{flag}/pre_loss_rec': loss_rec.item(), } slot_value = (iteration + 1, hps.enc_pretrain_iters) + tuple( [value for value in info.values()]) log = 'pre_G:[%06d/%06d], loss_rec=%.3f' print(log % slot_value) if iteration % 100 == 0: for tag, value in info.items(): self.logger.scalar_summary(tag, value, iteration + 1) elif mode == 'pretrain_D': for iteration in range(hps.dis_pretrain_iters): data = next(self.data_loader) c, x = self.permute_data(data) # encode enc = self.encode_step(x) # classify speaker logits = self.clf_step(enc) loss_clf = self.cal_loss(logits, c) # update reset_grad([self.SpeakerClassifier]) loss_clf.backward() grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm) self.clf_opt.step() # calculate acc acc = cal_acc(logits, c) info = { f'{flag}/pre_loss_clf': loss_clf.item(), f'{flag}/pre_acc': acc, } slot_value = (iteration + 1, hps.dis_pretrain_iters) + tuple( [value for value in info.values()]) log = 'pre_D:[%06d/%06d], loss_clf=%.2f, acc=%.2f' print(log % slot_value) if iteration % 100 == 0: for tag, value in info.items(): self.logger.scalar_summary(tag, value, iteration + 1) elif mode == 'patchGAN': for iteration in range(hps.patch_iters): #=======train D=========# for step in range(hps.n_patch_steps): data = next(self.data_loader) c, x = self.permute_data(data) ## encode enc = self.encode_step(x) # sample c c_prime = self.sample_c(x.size(0)) # generator x_tilde = self.gen_step(enc, c_prime) # discriminstor w_dis, real_logits, gp = 
self.patch_step(x, x_tilde, is_dis=True) # aux classification loss loss_clf = self.cal_loss(real_logits, c) loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp reset_grad([self.PatchDiscriminator]) loss.backward() grad_clip([self.PatchDiscriminator], self.hps.max_grad_norm) self.patch_opt.step() # calculate acc acc = cal_acc(real_logits, c) info = { f'{flag}/w_dis': w_dis.item(), f'{flag}/gp': gp.item(), f'{flag}/real_loss_clf': loss_clf.item(), f'{flag}/real_acc': acc, } slot_value = (step, iteration + 1, hps.patch_iters) + tuple( [value for value in info.values()]) log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f' print(log % slot_value) if iteration % 100 == 0: for tag, value in info.items(): self.logger.scalar_summary(tag, value, iteration + 1) #=======train G=========# data = next(self.data_loader) c, x = self.permute_data(data) # encode enc = self.encode_step(x) # sample c c_prime = self.sample_c(x.size(0)) # generator x_tilde = self.gen_step(enc, c_prime) # discriminstor loss_adv, fake_logits = self.patch_step(x, x_tilde, is_dis=False) # aux classification loss loss_clf = self.cal_loss(fake_logits, c_prime) loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv reset_grad([self.Generator]) loss.backward() grad_clip([self.Generator], self.hps.max_grad_norm) self.gen_opt.step() # calculate acc acc = cal_acc(fake_logits, c_prime) info = { f'{flag}/loss_adv': loss_adv.item(), f'{flag}/fake_loss_clf': loss_clf.item(), f'{flag}/fake_acc': acc, } slot_value = (iteration + 1, hps.patch_iters) + tuple( [value for value in info.values()]) log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f' print(log % slot_value) if iteration % 100 == 0: for tag, value in info.items(): self.logger.scalar_summary(tag, value, iteration) #===================== Train G =====================# data = next(self.data_loader) (c_i, c_j), (x_i_t, x_i_tk, x_i_prime, x_j) = self.permute_data(data) # encode enc_i_t, enc_i_tk, 
enc_i_prime, enc_j = self.encode_step( x_i_t, x_i_tk, x_i_prime, x_j) # decode x_tilde = self.decode_step(enc_i_t, c_i) loss_rec = torch.mean(torch.abs(x_tilde - x_i_t)) # latent discriminate loss_adv = self.latent_discriminate_step(enc_i_t, enc_i_tk, enc_i_prime, enc_j, is_dis=False) ae_loss = loss_rec + current_alpha * loss_adv reset_grad([self.Encoder, self.Decoder]) retain_graph = True if hps.n_patch_steps > 0 else False ae_loss.backward(retain_graph=retain_graph) grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm) self.ae_opt.step() info = { f'{flag}/loss_rec': loss_rec.data[0], f'{flag}/loss_adv': loss_adv.data[0], f'{flag}/alpha': current_alpha, } slot_value = (iteration + 1, hps.iters) + tuple( [value for value in info.values()]) log = 'G:[%06d/%06d], loss_rec=%.2f, loss_adv=%.2f, alpha=%.2e' print(log % slot_value) for tag, value in info.items(): self.logger.scalar_summary(tag, value, iteration + 1) # patch discriminate if hps.n_patch_steps > 0 and iteration >= hps.patch_start_iter: c_sample = self.sample_c(x_i_t.size(0)) x_tilde = self.decode_step(enc_i_t, c_sample) patch_w_dis, real_logits, fake_logits = \ self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=False) patch_loss = hps.beta_dec * patch_w_dis + hps.beta_clf * c_loss reset_grad([self.Decoder]) patch_loss.backward() grad_clip([self.Decoder], self.hps.max_grad_norm) self.decoder_opt.step() info = { f'{flag}/loss_rec': loss_rec.item(), f'{flag}/G_loss_clf': loss_clf.item(), f'{flag}/alpha': current_alpha, f'{flag}/G_acc': acc, } slot_value = (iteration + 1, hps.iters) + tuple( [value for value in info.values()]) log = 'G:[%06d/%06d], loss_rec=%.3f, loss_clf=%.2f, alpha=%.2e, acc=%.2f' print(log % slot_value) if iteration % 100 == 0: for tag, value in info.items(): self.logger.scalar_summary(tag, value, iteration + 1) if iteration % 1000 == 0 or iteration + 1 == hps.iters: self.save_model(model_path, iteration)
def __init__(self, policy, ob_space, ac_space, nenvs, master_ts=1,
             worker_ts=30, ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5,
             lr=7e-4, cell=256, alpha=0.99, epsilon=1e-5,
             total_timesteps=int(80e6), lrschedule='linear', algo='regular',
             beta=1e-3):
    """Build the step/train graphs, the loss and the training closures.

    The loss is ``pg_loss + vf_coef * vf_loss`` (with the entropy bonus
    already folded into ``pg_loss``); for ``algo == 'VIB'`` a term
    ``beta * kl_loss`` is added, where ``kl_loss`` is the mean KL divergence
    between ``train_model.encoding`` and a standard normal prior.

    Attributes set on ``self``: ``train_model``, ``step_model``, ``step``,
    ``value``, ``get_wh``, ``initial_state``, ``train``, ``save``, ``load``.

    Args:
        policy: callable building a policy network; called as shown below.
        ob_space, ac_space: forwarded to ``policy``.
        nenvs: number of environments; the training batch size is
            ``nenvs * master_ts * worker_ts``.
        master_ts, worker_ts: forwarded to ``policy`` for the training model.
        ent_coef: weight of the entropy bonus.
        vf_coef: weight of the value-function loss.
        max_grad_norm: forwarded to ``grad_clip``.
        lr: initial value handed to ``Scheduler``.
        cell: forwarded to both policy instances.
        alpha, epsilon: RMSProp ``decay`` and ``epsilon``.
        total_timesteps, lrschedule: forwarded to ``Scheduler``.
        algo: ``'regular'`` or ``'VIB'``.
        beta: weight of the KL term for ``'VIB'``.

    Raises:
        ValueError: if ``algo`` is neither ``'regular'`` nor ``'VIB'``.
    """
    print('Create Session')
    gpu_options = tf.GPUOptions(allow_growth=True)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    nbatch = nenvs * master_ts * worker_ts

    A = tf.placeholder(tf.int32, [nbatch])
    ADV = tf.placeholder(tf.float32, [nbatch])
    R = tf.placeholder(tf.float32, [nbatch])
    LR = tf.placeholder(tf.float32, [])

    step_model = policy(sess, ob_space, ac_space, nenvs, 1, 1, cell=cell,
                        model='step_model', algo=algo)
    # Pass `cell` here as well so both instances are built with the same
    # size instead of the training model falling back to the policy default.
    train_model = policy(sess, ob_space, ac_space, nbatch, master_ts,
                         worker_ts, cell=cell, model='train_model', algo=algo)
    print('model_setting_done')

    # loss construction
    neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=train_model.wpi, labels=A)
    pg_loss = tf.reduce_mean(ADV * neglogpac)
    vf_loss = tf.reduce_mean(mse(tf.squeeze(train_model.wvf), R))
    entropy = tf.reduce_mean(cat_entropy(train_model.wpi))
    pg_loss = pg_loss - entropy * ent_coef
    print('algo: ', algo, 'max_grad_norm: ', str(max_grad_norm))

    if algo == 'regular':
        loss = pg_loss + vf_coef * vf_loss
    elif algo == 'VIB':
        # Third loss term: KL divergence between the encoding distribution
        # and a Gaussian prior with mu=0, std=1.
        prior = ds.Normal(0.0, 1.0)
        kl_loss = tf.reduce_mean(ds.kl_divergence(train_model.encoding, prior))
        loss = pg_loss + vf_coef * vf_loss + beta * kl_loss
    else:
        # Fail here with a clear message; printing the error and continuing
        # only postponed the failure to the first use of the unbound `loss`.
        raise ValueError('Algorithm not exists: %r' % (algo,))

    grads, global_norm = grad_clip(loss, max_grad_norm, ['model'])
    trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha,
                                        epsilon=epsilon)
    _train = trainer.apply_gradients(grads)
    lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)

    def train(wobs, whs, states, rewards, masks, actions, values):
        """Run one update; returns (total, value, policy, entropy) losses."""
        advs = rewards - values
        # Query the schedule once per element of `whs` and keep the last
        # value (presumably each call advances the schedule -- confirm
        # against Scheduler).
        for _ in range(len(whs)):
            cur_lr = lr.value()
        td_map = {train_model.wX: wobs, A: actions, ADV: advs, R: rewards,
                  LR: cur_lr}
        if states is not None:
            td_map[train_model.wS] = states
            td_map[train_model.wM] = masks
        # Additional tensors (e.g. kl_loss for VIB) can be added to this run
        # call for debugging.
        tloss, value_loss, policy_loss, policy_entropy, _ = sess.run(
            [loss, vf_loss, pg_loss, entropy, _train],
            feed_dict=td_map
        )
        return tloss, value_loss, policy_loss, policy_entropy

    params = find_trainable_variables("model")

    def save(save_path):
        """Dump the current values of ``params`` to ``save_path``."""
        ps = sess.run(params)
        make_path(osp.dirname(save_path))
        joblib.dump(ps, save_path)

    def load(load_path):
        """Assign values loaded from ``load_path`` to ``params``, in order."""
        loaded_params = joblib.load(load_path)
        restores = [p.assign(loaded_p)
                    for p, loaded_p in zip(params, loaded_params)]
        sess.run(restores)

    self.train_model = train_model
    self.step_model = step_model
    self.step = step_model.step
    self.value = step_model.wvalue
    self.get_wh = step_model.get_wh
    self.initial_state = step_model.w_initial_state
    self.train = train
    self.save = save
    self.load = load
    tf.global_variables_initializer().run(session=sess)
def train(self, model_path, flag='train', mode='train', target_guided=False):
    """Run the training stage selected by ``mode``.

    Supported modes: ``'pretrain_AE'``, ``'pretrain_C'``, ``'train'``,
    ``'patchGAN'``, ``'autolocker'``, ``'t_classify'`` and
    ``'train_Tacotron'``. Each stage prints a carriage-return progress line
    every iteration, sends its scalars to ``self.logger`` when
    ``iteration % 100 == 0`` and calls ``self.save_model`` when
    ``(iteration + 1) % 1000 == 0``.

    Args:
        model_path: forwarded unchanged to ``self.save_model``.
        flag: prefix of every scalar tag (``'<flag>/<name>'``).
        mode: name of the stage to run.
        target_guided: in the ``'patchGAN'`` and ``'autolocker'`` stages, also
            take an L1 reconstruction step on the target batch.

    Raises:
        NotImplementedError: if ``mode`` is not one of the names above.
    """
    hps = self.hps

    def report(template, prefix, info, iteration):
        # Overwrite the progress line, then log scalars every 100 iterations.
        print(template % (prefix + tuple(info.values())), end='\r')
        if iteration % 100 == 0:
            for tag, value in info.items():
                self.logger.scalar_summary(tag, value, iteration + 1)

    def checkpoint(stage, iteration):
        # Save every 1000 iterations and keep the current progress line.
        if (iteration + 1) % 1000 == 0:
            self.save_model(model_path, stage, iteration + 1)
            print()

    if mode == 'pretrain_AE':
        for iteration in range(hps.enc_pretrain_iters):
            c, x = self.permute_data(next(self.data_loader))
            # autoencoder reconstruction (L1)
            enc_act, _ = self.encode_step(x)
            x_dec = self.decode_step(enc_act, c)
            loss_rec = torch.mean(torch.abs(x_dec - x))
            reset_grad([self.Encoder, self.Decoder])
            loss_rec.backward()
            grad_clip([self.Encoder, self.Decoder], hps.max_grad_norm)
            self.ae_opt.step()
            info = {f'{flag}/pre_loss_rec': loss_rec.item()}
            report('pre_AE:[%06d/%06d], loss_rec=%.3f',
                   (iteration + 1, hps.enc_pretrain_iters), info, iteration)
            checkpoint('ae', iteration)

    elif mode == 'pretrain_C':
        for iteration in range(hps.dis_pretrain_iters):
            c, x = self.permute_data(next(self.data_loader))
            _, enc = self.encode_step(x)
            # speaker classification on the encoding
            logits = self.clf_step(enc)
            loss_clf = self.cal_loss(logits, c)
            reset_grad([self.SpeakerClassifier])
            loss_clf.backward()
            grad_clip([self.SpeakerClassifier], hps.max_grad_norm)
            self.clf_opt.step()
            acc = self.cal_acc(logits, c)
            info = {
                f'{flag}/pre_loss_clf': loss_clf.item(),
                f'{flag}/pre_acc': acc,
            }
            report('pre_C:[%06d/%06d], loss_clf=%.2f, acc=%.2f',
                   (iteration + 1, hps.dis_pretrain_iters), info, iteration)
            checkpoint('c', iteration)

    elif mode == 'train':
        for iteration in range(hps.iters):
            # linear warm-up of the adversarial weight
            current_alpha = hps.alpha_enc
            if iteration < hps.lat_sched_iters:
                current_alpha = hps.alpha_enc * (iteration / hps.lat_sched_iters)

            # ---- classifier (D) updates ----
            for step in range(hps.n_latent_steps):
                c, x = self.permute_data(next(self.data_loader))
                _, enc = self.encode_step(x)
                logits = self.clf_step(enc)
                loss_clf = self.cal_loss(logits, c)
                loss = hps.alpha_dis * loss_clf
                reset_grad([self.SpeakerClassifier])
                loss.backward()
                grad_clip([self.SpeakerClassifier], hps.max_grad_norm)
                self.clf_opt.step()
                acc = self.cal_acc(logits, c)
                info = {
                    f'{flag}/D_loss_clf': loss_clf.item(),
                    f'{flag}/D_acc': acc,
                }
                report('D-%d:[%06d/%06d], loss_clf=%.2f, acc=%.2f',
                       (step, iteration + 1, hps.iters), info, iteration)

            # ---- autoencoder (G) update ----
            c, x = self.permute_data(next(self.data_loader))
            enc_act, enc = self.encode_step(x)
            x_dec = self.decode_step(enc_act, c)
            loss_rec = torch.mean(torch.abs(x_dec - x))
            logits = self.clf_step(enc)
            acc = self.cal_acc(logits, c)
            loss_clf = self.cal_loss(logits, c)
            # the classification loss is subtracted, i.e. maximised
            loss = loss_rec - current_alpha * loss_clf
            reset_grad([self.Encoder, self.Decoder])
            loss.backward()
            grad_clip([self.Encoder, self.Decoder], hps.max_grad_norm)
            self.ae_opt.step()
            info = {
                f'{flag}/loss_rec': loss_rec.item(),
                f'{flag}/G_loss_clf': loss_clf.item(),
                f'{flag}/alpha': current_alpha,
                f'{flag}/G_acc': acc,
            }
            report('G:[%06d/%06d], loss_rec=%.3f, loss_clf=%.2f, alpha=%.2e, acc=%.2f',
                   (iteration + 1, hps.iters), info, iteration)
            checkpoint('s1', iteration)

    elif mode == 'patchGAN':
        for iteration in range(hps.patch_iters):
            # ---- patch discriminator (D) updates ----
            for step in range(hps.n_patch_steps):
                data_s = next(self.source_loader)
                data_t = next(self.target_loader)
                _, x_s = self.permute_data(data_s)
                c_t, x_t = self.permute_data(data_t)
                enc_act, _ = self.encode_step(x_s)
                x_dec = self.gen_step(enc_act, c_t)
                w_dis, real_logits, gp = self.patch_step(x_t, x_dec, is_dis=True)
                # auxiliary classification loss on the real logits
                loss_clf = self.cal_loss(real_logits, c_t, shift=True)
                loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                reset_grad([self.PatchDiscriminator])
                loss.backward()
                grad_clip([self.PatchDiscriminator], hps.max_grad_norm)
                self.patch_opt.step()
                acc = self.cal_acc(real_logits, c_t, shift=True)
                info = {
                    f'{flag}/w_dis': w_dis.item(),
                    f'{flag}/gp': gp.item(),
                    f'{flag}/real_loss_clf': loss_clf.item(),
                    f'{flag}/real_acc': acc,
                }
                report('patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f',
                       (step, iteration + 1, hps.patch_iters), info, iteration)

            # ---- generator (G) update ----
            data_s = next(self.source_loader)
            data_t = next(self.target_loader)
            _, x_s = self.permute_data(data_s)
            c_t, x_t = self.permute_data(data_t)
            enc_act, _ = self.encode_step(x_s)
            x_dec = self.gen_step(enc_act, c_t)
            loss_adv, fake_logits = self.patch_step(x_t, x_dec, is_dis=False)
            # auxiliary classification loss on the fake logits
            loss_clf = self.cal_loss(fake_logits, c_t, shift=True)
            loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
            reset_grad([self.Generator])
            loss.backward()
            grad_clip([self.Generator], hps.max_grad_norm)
            self.gen_opt.step()
            if target_guided:
                # teacher forcing: reconstruct the target batch itself
                enc_tf, _ = self.encode_step(x_t)
                x_dec_tf = self.gen_step(enc_tf, c_t)
                loss_rec = torch.mean(torch.abs(x_dec_tf - x_t))
                reset_grad([self.Generator])
                loss_rec.backward()
                self.gen_opt.step()
            acc = self.cal_acc(fake_logits, c_t, shift=True)
            info = {
                f'{flag}/loss_adv': loss_adv.item(),
                f'{flag}/fake_loss_clf': loss_clf.item(),
                f'{flag}/fake_acc': acc,
                f'{flag}/tg_rec': loss_rec.item() if target_guided else 0.000,
            }
            report('patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f, tg_rec=%.3f',
                   (iteration + 1, hps.patch_iters), info, iteration)
            checkpoint('s2', iteration)

    elif mode == 'autolocker':
        criterion = torch.nn.BCELoss()
        for iteration in range(hps.patch_iters):
            data_s = next(self.source_loader)
            data_t = next(self.target_loader)
            _, x_s = self.permute_data(data_s)
            c_t, x_t = self.permute_data(data_t)
            enc_act, _ = self.encode_step(x_s)
            residual_output = self.gen_step(enc_act, c_t)
            # re-encode the generated output and compare with the detached
            # original encoding
            re_enc, _ = self.encode_step(residual_output)
            loss_reenc = criterion(re_enc, enc_act.data)
            reset_grad([self.Encoder, self.Decoder, self.Generator])
            loss_reenc.backward()
            grad_clip([self.Generator], hps.max_grad_norm)
            self.gen_opt.step()
            if target_guided:
                # teacher forcing: reconstruct the target batch itself
                enc_tf, _ = self.encode_step(x_t)
                x_dec_tf = self.gen_step(enc_tf, c_t)
                loss_rec = torch.mean(torch.abs(x_dec_tf - x_t))
                reset_grad([self.Encoder, self.Decoder, self.Generator])
                loss_rec.backward()
                self.gen_opt.step()
            info = {
                f'{flag}/re_enc': loss_reenc.item(),
                f'{flag}/tg_rec': loss_rec.item() if target_guided else 0.000,
            }
            report('patch_G:[%06d/%06d], re_enc=%.3f, tg_rec=%.3f',
                   (iteration + 1, hps.patch_iters), info, iteration)
            checkpoint('s2', iteration)

    elif mode == 't_classify':
        for iteration in range(hps.tclf_iters):
            c, x = self.permute_data(next(self.data_loader))
            # every label below 100 is remapped to 102
            c[c < 100] = 102
            logits = self.tclf_step(x)
            loss = self.cal_loss(logits, c - self.shift_c)
            reset_grad([self.TargetClassifier])
            loss.backward()
            grad_clip([self.TargetClassifier], hps.max_grad_norm)
            self.tclf_opt.step()
            acc = self.cal_acc(logits, c - self.shift_c)
            info = {f'{flag}/acc': acc}
            report('Target Classifier:[%05d/%05d], acc=%.2f',
                   (iteration + 1, hps.tclf_iters), info, iteration)
            checkpoint('tclf', iteration)

    elif mode == 'train_Tacotron':
        assert self.g_mode == 'tacotron'
        criterion = TacotronLoss()
        self.Encoder.eval()
        for iteration in range(hps.tacotron_iters):
            # set this iteration's learning rate on every parameter group
            cur_lr = learning_rate_decay(init_lr=0.002, global_step=iteration)
            for param_group in self.gen_opt.param_groups:
                param_group['lr'] = cur_lr
            c, x, m = self.permute_data(next(self.data_loader), load_mel=True)
            enc_act, _ = self.encode_step(x)
            # tacotron synthesis from the detached encoding
            m_dec, x_dec = self.tacotron_step(enc_act.data, m, c)
            loss_rec = criterion([m_dec, x_dec], [m, x])
            reset_grad([self.Generator])
            loss_rec.backward()
            grad_clip([self.Generator], hps.max_grad_norm)
            self.gen_opt.step()
            info = {
                f'{flag}/tacotron_loss_rec': loss_rec.item(),
                f'{flag}/tacotron_lr': cur_lr,
            }
            report('train_Tacotron:[%06d/%06d], loss_rec=%.3f, lr=%.2e',
                   (iteration + 1, hps.tacotron_iters), info, iteration)
            checkpoint('t', iteration)

    else:
        raise NotImplementedError()
def train(self, model_path, flag='train', mode='train'):
    """Run the training stage selected by ``mode``.

    Supported modes: ``'pretrain_AE'``, ``'pretrain_C'``, ``'train'`` and
    ``'patchGAN'``. Each stage prints a carriage-return progress line every
    iteration, sends its scalars to ``self.logger`` when
    ``iteration % 100 == 0`` and calls ``self.save_model`` when
    ``(iteration + 1) % 1000 == 0``.

    Args:
        model_path: forwarded unchanged to ``self.save_model``.
        flag: prefix of every scalar tag (``'<flag>/<name>'``).
        mode: name of the stage to run.

    Raises:
        NotImplementedError: if ``mode`` is not one of the names above.
    """
    hps = self.hps

    def report(template, prefix, info, iteration):
        # Overwrite the progress line, then log scalars every 100 iterations.
        print(template % (prefix + tuple(info.values())), end='\r')
        if iteration % 100 == 0:
            for tag, value in info.items():
                self.logger.scalar_summary(tag, value, iteration + 1)

    def checkpoint(stage, iteration, saved_step):
        # Save every 1000 iterations and keep the current progress line.
        if (iteration + 1) % 1000 == 0:
            self.save_model(model_path, stage, saved_step)
            print()

    if mode == 'pretrain_AE':
        for iteration in range(hps.enc_pretrain_iters):
            c, x = self.permute_data(next(self.data_loader))
            # autoencoder reconstruction (L1)
            enc = self.encode_step(x)
            x_tilde = self.decode_step(enc, c)
            loss_rec = torch.mean(torch.abs(x_tilde - x))
            reset_grad([self.Encoder, self.Decoder])
            loss_rec.backward()
            grad_clip([self.Encoder, self.Decoder], hps.max_grad_norm)
            self.ae_opt.step()
            info = {f'{flag}/pre_loss_rec': loss_rec.item()}
            report('pre_AE:[%06d/%06d], loss_rec=%.3f',
                   (iteration + 1, hps.enc_pretrain_iters), info, iteration)
            checkpoint('ae', iteration, iteration + 1)

    elif mode == 'pretrain_C':
        for iteration in range(hps.dis_pretrain_iters):
            c, x = self.permute_data(next(self.data_loader))
            enc = self.encode_step(x)
            # speaker classification on the encoding
            logits = self.clf_step(enc)
            loss_clf = self.cal_loss(logits, c)
            reset_grad([self.SpeakerClassifier])
            loss_clf.backward()
            grad_clip([self.SpeakerClassifier], hps.max_grad_norm)
            self.clf_opt.step()
            acc = self.cal_acc(logits, c)
            info = {
                f'{flag}/pre_loss_clf': loss_clf.item(),
                f'{flag}/pre_acc': acc,
            }
            report('pre_C:[%06d/%06d], loss_clf=%.2f, acc=%.2f',
                   (iteration + 1, hps.dis_pretrain_iters), info, iteration)
            checkpoint('c', iteration, iteration + 1)

    elif mode == 'train':
        for iteration in range(hps.iters):
            # linear warm-up of the adversarial weight
            current_alpha = hps.alpha_enc
            if iteration < hps.lat_sched_iters:
                current_alpha = hps.alpha_enc * \
                    (iteration / hps.lat_sched_iters)

            # ---- classifier (D) updates ----
            for step in range(hps.n_latent_steps):
                c, x = self.permute_data(next(self.data_loader))
                enc = self.encode_step(x)
                # the encoder output unpacks into three parts; the last two
                # are used as mean and log-variance of the KL term
                _, z_mean, z_log_var = enc
                kl_loss = (1 + z_log_var - z_mean**2 - torch.exp(z_log_var)).sum(-1) * -.5
                kl_loss = kl_loss.sum()
                logits = self.clf_step(enc)
                loss_clf = self.cal_loss(logits, c)
                # NOTE(review): the KL term is added here although only the
                # SpeakerClassifier is reset/clipped/stepped -- confirm intent.
                loss = hps.alpha_dis * loss_clf + kl_loss
                reset_grad([self.SpeakerClassifier])
                loss.backward()
                grad_clip([self.SpeakerClassifier], hps.max_grad_norm)
                self.clf_opt.step()
                acc = self.cal_acc(logits, c)
                info = {
                    f'{flag}/D_loss_clf': loss_clf.item(),
                    f'{flag}/D_acc': acc,
                }
                report('D-%d:[%06d/%06d], loss_clf=%.2f, acc=%.2f',
                       (step, iteration + 1, hps.iters), info, iteration)

            # ---- autoencoder (G) update ----
            c, x = self.permute_data(next(self.data_loader))
            enc = self.encode_step(x)
            x_tilde = self.decode_step(enc, c)
            loss_rec = torch.mean(torch.abs(x_tilde - x))
            logits = self.clf_step(enc)
            acc = self.cal_acc(logits, c)
            loss_clf = self.cal_loss(logits, c)
            # the classification loss is subtracted, i.e. maximised
            loss = loss_rec - current_alpha * loss_clf
            reset_grad([self.Encoder, self.Decoder])
            loss.backward()
            grad_clip([self.Encoder, self.Decoder], hps.max_grad_norm)
            self.ae_opt.step()
            info = {
                f'{flag}/loss_rec': loss_rec.item(),
                f'{flag}/G_loss_clf': loss_clf.item(),
                f'{flag}/alpha': current_alpha,
                f'{flag}/G_acc': acc,
            }
            report('G:[%06d/%06d], loss_rec=%.3f, loss_clf=%.2f, alpha=%.2e, acc=%.2f',
                   (iteration + 1, hps.iters), info, iteration)
            checkpoint('s1', iteration, iteration + 1)

    elif mode == 'patchGAN':
        for iteration in range(hps.patch_iters):
            # ---- patch discriminator (D) updates ----
            for step in range(hps.n_patch_steps):
                data_s = next(self.source_loader)
                data_t = next(self.target_loader)
                _, x_s = self.permute_data(data_s)
                c, x_t = self.permute_data(data_t)
                enc = self.encode_step(x_s)
                # sampled condition for the generator
                c_prime = self.sample_c(x_t.size(0))
                x_tilde = self.gen_step(enc, c_prime)
                w_dis, real_logits, gp = self.patch_step(x_t, x_tilde, is_dis=True)
                # auxiliary classification loss on the real logits
                loss_clf = self.cal_loss(real_logits, c, shift=True)
                loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                reset_grad([self.PatchDiscriminator])
                loss.backward()
                grad_clip([self.PatchDiscriminator], hps.max_grad_norm)
                self.patch_opt.step()
                acc = self.cal_acc(real_logits, c, shift=True)
                info = {
                    f'{flag}/w_dis': w_dis.item(),
                    f'{flag}/gp': gp.item(),
                    f'{flag}/real_loss_clf': loss_clf.item(),
                    f'{flag}/real_acc': acc,
                }
                report('patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f',
                       (step, iteration + 1, hps.patch_iters), info, iteration)

            # ---- generator (G) update ----
            data_s = next(self.source_loader)
            data_t = next(self.target_loader)
            _, x_s = self.permute_data(data_s)
            c, x_t = self.permute_data(data_t)
            enc = self.encode_step(x_s)
            c_prime = self.sample_c(x_t.size(0))
            x_tilde = self.gen_step(enc, c_prime)
            loss_adv, fake_logits = self.patch_step(x_t, x_tilde, is_dis=False)
            # auxiliary classification loss against the sampled condition
            loss_clf = self.cal_loss(fake_logits, c_prime, shift=True)
            loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
            reset_grad([self.Generator])
            loss.backward()
            grad_clip([self.Generator], hps.max_grad_norm)
            self.gen_opt.step()
            acc = self.cal_acc(fake_logits, c_prime, shift=True)
            info = {
                f'{flag}/loss_adv': loss_adv.item(),
                f'{flag}/fake_loss_clf': loss_clf.item(),
                f'{flag}/fake_acc': acc,
            }
            report('patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f',
                   (iteration + 1, hps.patch_iters), info, iteration)
            # the saved step is offset by the length of the 'train' stage
            checkpoint('s2', iteration, iteration + 1 + hps.iters)

    else:
        raise NotImplementedError()
def __init__(self, policy, ob_space, ac_space, nenvs, master_ts = 1, worker_ts = 8,
        ent_coef=0.01, vf_coef=0.5, max_grad_norm=2.5, lr=7e-4, cell=256,
        ib_alpha=0.04, sv_M=32, algo='use_svib_uniform', alpha=0.99, epsilon=1e-5,
        total_timesteps=int(80e6), lrschedule='linear'):
    """Build the policy graphs, the losses for the chosen ``algo`` and the
    RMSProp updates, and attach ``train``/``train_mine``/``save``/``load``.

    Args:
        policy: callable used to build both the step model and the train model.
        ob_space, ac_space: forwarded to ``policy``; ``ac_space.n`` is read.
        nenvs, master_ts, worker_ts: their product is the training batch size.
        ent_coef: weight of the entropy bonus.
        vf_coef: weight of the value-function loss (also scales the loss of
            the ``model/T/update_params`` update).
        max_grad_norm: forwarded to every ``grad_clip`` call.
        lr: initial learning rate handed to ``Scheduler``.
        cell: forwarded to ``policy``; last dimension of the ``SV_GRADS``
            placeholder in the SVIB branch.
        ib_alpha: coefficient used in the SVIB / sv_a2c gradient terms.
        sv_M: forwarded to ``policy`` as ``M``; middle dimension of ``SV_GRADS``.
        algo: ``'use_svib_uniform'``, ``'use_svib_gaussian'``, ``'sv_a2c'`` or
            ``'anchor'``; any other value builds the regular loss.
        alpha, epsilon: RMSProp ``decay`` and ``epsilon``.
        total_timesteps, lrschedule: forwarded to ``Scheduler``.
    """
    sess = tf_util.make_session()
    # Not used below; kept because reading `.n` fails fast on action spaces
    # that do not have it.
    nact = ac_space.n
    nbatch = nenvs*master_ts*worker_ts

    # A: action, ADV: advantage, R: reward, LR: learning rate
    A = tf.placeholder(tf.int32, [nbatch])
    ADV = tf.placeholder(tf.float32, [nbatch])
    R = tf.placeholder(tf.float32, [nbatch])
    LR = tf.placeholder(tf.float32, [])

    step_model = policy(sess, ob_space, ac_space, nenvs, 1, 1, cell=cell, M=sv_M, model='step_model', algo=algo)
    train_model = policy(sess, ob_space, ac_space, nbatch, master_ts, worker_ts, cell = cell, M=sv_M, model='train_model', algo=algo)
    print('model_setting_done, algorithm:', str(algo))

    # Mutual-information estimate (original author's note: visualisation of
    # the mutual information is skipped for now).
    ib_loss = train_model.mi_xh_loss
    T = train_model.T_value
    t_grads, t_global_norm = grad_clip(-vf_coef*ib_loss, max_grad_norm, ['model/T/update_params'])
    t_trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
    _t_train = t_trainer.apply_gradients(t_grads)
    T_update_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='model/T/update_params')
    T_orig_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='model/T/orig_params')
    # Ops that copy the `orig_params` values back onto `update_params`.
    reset_update_params = [update_param.assign(orig_param)
                           for update_param, orig_param in zip(T_update_params, T_orig_params)]

    def expand_placeholder(X, M=sv_M):
        # Append a trailing axis and tile it M times: [n] -> [n, M].
        return tf.tile(tf.expand_dims(X, axis=-1), [1, M])

    if algo == 'use_svib_uniform' or algo == 'use_svib_gaussian':
        A_expand, R_expand = expand_placeholder(A), expand_placeholder(R)
        neglogpac_expand = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.wpi_expand, labels=A_expand)  # shape=[nbatch, sv_M]
        pg_loss_expand = tf.reduce_mean(tf.stop_gradient(R_expand-train_model.wvf_expand[:,:,0]) * neglogpac_expand, axis=-1)
        vf_loss_expand = tf.reduce_mean(mse(tf.squeeze(train_model.wvf_expand), R_expand), axis=-1)
        entropy_expand = tf.reduce_mean(cat_entropy(train_model.wpi_expand), axis=-1)  # shape=[nbatch]
        J_theta = -(pg_loss_expand + vf_coef*vf_loss_expand - ent_coef*entropy_expand)
        loss_expand = -J_theta / float(nbatch)
        pg_loss_expand_ = tf.reduce_mean(pg_loss_expand)
        vf_loss_expand_ = tf.reduce_mean(vf_loss_expand)
        entropy_expand_ = tf.reduce_mean(entropy_expand)
        loss_expand_ = -tf.reduce_mean(J_theta)
        print('ib_alpha: ', ib_alpha)

        log_p_grads = tf.gradients(J_theta/np.sqrt(ib_alpha), [train_model.wh_expand])[0]  # shape=[nbatch, sv_M, cell]
        if algo == 'use_svib_gaussian':
            mean, var = tf.nn.moments(train_model.wh_expand, axes=1, keep_dims=True)  # shape=[nbatch, 1, cell]
            gaussian_grad = -(train_model.wh_expand - mean)/(float(sv_M) * (var+1e-3))
            log_p_grads += 5e-3*(tf_l2norm(log_p_grads, axis=-1, keep_dims=True)/tf_l2norm(gaussian_grad, axis=-1, keep_dims=True))*gaussian_grad

        # Build sv_grads by concatenating one [nbatch, 1, cell] slice per i.
        sv_grads = tf.constant(0., tf.float32, shape=[nbatch, 0, cell])
        for i in range(sv_M):
            sv_grad = tf.reduce_sum(train_model.rpf_matrix[:, :, i:i+1] * log_p_grads, axis=1) + np.sqrt(ib_alpha)*train_model.rpf_grads[:, i, :]  # shape=[nbatch, cell]
            sv_grads = tf.concat([sv_grads, tf.expand_dims(sv_grad, axis=1)], axis=1)

        # The evaluated sv_grads are fed back through this placeholder.
        SV_GRADS = tf.placeholder(tf.float32, [nbatch, sv_M, cell])
        repr_loss = tf.reduce_mean(SV_GRADS * train_model.wh_expand, axis=1)  # shape=[nbatch, cell]
        # negated: turns the maximisation problem into a minimisation
        repr_loss = -tf.reduce_mean(tf.reduce_sum(repr_loss, axis=-1))

        grads_expand, global_norm_expand = grad_clip(loss_expand, max_grad_norm, ['model/worker_module'])
        trainer_expand = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
        _train_expand = trainer_expand.apply_gradients(grads_expand)
        repr_grads, repr_global_norm = grad_clip(repr_loss, max_grad_norm, ['model/ordinary_encoder'])
        repr_trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
        _repr_train = repr_trainer.apply_gradients(repr_grads)

    elif algo == 'sv_a2c':
        A_expand, R_expand = expand_placeholder(A), expand_placeholder(R)
        sigma = tf.constant(1e-5)
        neglogpac_expand = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.wpi_expand, labels=A_expand) + sigma
        pg_loss_expand = tf.reduce_mean(tf.stop_gradient(R_expand - train_model.wvf_expand[:, :, 0]) * neglogpac_expand, axis=-1)
        vf_loss_sv = tf.expand_dims(mse(tf.squeeze(train_model.wvf_expand), R_expand), axis=-1)
        vf_loss_expand = tf.reduce_mean(mse(tf.squeeze(train_model.wvf_expand), R_expand), axis=-1)
        entropy_expand = tf.reduce_mean(cat_entropy(train_model.wpi_expand), axis=-1)  # shape=[nbatch]
        J_theta = pg_loss_expand + vf_coef * vf_loss_expand - ent_coef * entropy_expand
        # Original author's open question: why divide by nbatch here?
        loss_expand = J_theta / float(nbatch)
        pg_loss_expand_ = tf.reduce_mean(pg_loss_expand)
        vf_loss_expand_ = tf.reduce_mean(vf_loss_expand)
        entropy_expand_ = tf.reduce_mean(entropy_expand)
        loss_expand_ = tf.reduce_mean(J_theta)
        print('ib_alpha: ', ib_alpha)

        mean, var = tf.nn.moments(vf_loss_sv, axes=1, keep_dims=True)
        # Original author's note ("Problem1"): this Gaussian gradient
        # computation is flagged as problematic.
        log_p_grads = -(vf_loss_sv - mean) / (float(sv_M) * (var))
        sv_grads = tf.constant(0., tf.float32, shape=[nbatch, 0, 1])
        rpf_h = self.h_coef(vf_loss_sv, sv_M)
        rpf_matrix, rpf_grads = self.rpf_kernel(vf_loss_sv, rpf_h, sv_M)
        for i in range(sv_M):
            sv_grad = tf.reduce_sum(rpf_matrix[:, :, i:i + 1] * log_p_grads, axis=1) + rpf_grads[:, i, :]
            sv_grads = tf.concat([sv_grads, tf.expand_dims(sv_grad, axis=1)], axis=1)

        SV_GRADS = tf.placeholder(tf.float32, [nbatch, sv_M, 1])
        sv_loss = tf.reduce_mean(SV_GRADS * vf_loss_sv, axis=1)
        loss_expand -= ib_alpha * (tf_l2norm(loss_expand, axis=-1, keep_dims=True)/tf_l2norm(sv_loss, axis=-1, keep_dims=True)) * sv_loss
        grads_expand, global_norm_expand = grad_clip(loss_expand, max_grad_norm, ['model/worker_module', 'model/ordinary_encoder'])
        trainer_expand = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
        _train_expand = trainer_expand.apply_gradients(grads_expand)

    elif algo == 'anchor':
        neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.wpi, labels=A)
        pg_loss = tf.reduce_mean(ADV * neglogpac)
        entropy = tf.reduce_mean(cat_entropy(train_model.wpi))
        vf_loss = tf.reduce_mean(mse(tf.squeeze(train_model.wvf), R))

        # Anchor method: penalise the squared distance between selected
        # worker-module variables and freshly drawn random-normal tensors.
        param_list = []
        for scope in ['model/worker_module']:
            scope_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
            print(len(scope_vars))
            param_list += scope_vars
        # presumably indices 0/1 and 4/5 are the value head's layer-1 and
        # layer-2 weights/biases -- TODO confirm against the policy definition
        param_value_layer1_w = param_list[0]
        param_value_layer1_b = param_list[1]
        param_value_layer2_w = param_list[4]
        param_value_layer2_b = param_list[5]

        init_stddev = 5.0
        init_stddev_2 = 0.18 / np.sqrt(cell)  # normal scaling
        lambda_anchor = [0.000001,0.1]
        layer1_w_init = tf.random_normal(mean=0., stddev=init_stddev, shape=param_value_layer1_w.get_shape())
        layer1_b_init = tf.random_normal(mean=0., stddev=init_stddev, shape=param_value_layer1_b.get_shape())
        layer2_w_init = tf.random_normal(mean=0, stddev=init_stddev_2, shape=param_value_layer2_w.get_shape())
        layer2_b_init = tf.random_normal(mean=0, stddev=init_stddev_2, shape=param_value_layer2_b.get_shape())
        loss_anchor = lambda_anchor[0] / nbatch * tf.reduce_sum(tf.square(layer1_w_init - param_value_layer1_w))
        loss_anchor += lambda_anchor[0] / nbatch * tf.reduce_sum(tf.square(layer1_b_init - param_value_layer1_b))
        loss_anchor += lambda_anchor[1] / nbatch * tf.reduce_sum(tf.square(layer2_w_init - param_value_layer2_w))
        loss_anchor += lambda_anchor[1] / nbatch * tf.reduce_sum(tf.square(layer2_b_init - param_value_layer2_b))

        loss = pg_loss + vf_coef * vf_loss - ent_coef * entropy + loss_anchor
        grads, global_norm = grad_clip(loss, max_grad_norm, ['model/worker_module', 'model/ordinary_encoder'])
        trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
        _train = trainer.apply_gradients(grads)

    else:
        # regular algorithm
        neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.wpi, labels=A)
        pg_loss = tf.reduce_mean(ADV * neglogpac)
        entropy = tf.reduce_mean(cat_entropy(train_model.wpi))
        vf_loss = tf.reduce_mean(mse(tf.squeeze(train_model.wvf), R))
        loss = pg_loss + vf_coef * vf_loss - ent_coef * entropy
        grads, global_norm = grad_clip(loss, max_grad_norm, ['model/worker_module', 'model/ordinary_encoder'])
        trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
        _train = trainer.apply_gradients(grads)

    params = find_trainable_variables("model")
    lr_scheduler = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)

    def train(wobs, whs, states, rewards, masks, actions, values, noises):
        """Run one optimisation step for the configured ``algo``.

        Returns a 9-tuple ``(total_loss, value_loss, policy_loss,
        policy_entropy, rl_grad_norm, repr_grad_norm, represent_loss,
        anchor_loss, sv_loss)``; entries a branch does not compute are 0.
        """
        advs = rewards - values
        # Advance the learning-rate schedule once per element of `whs`;
        # the last value is the one fed to the optimisers.
        for step in range(len(whs)):
            cur_lr = lr_scheduler.value()

        sv_td_map = {train_model.wX : wobs, train_model.istraining:True, A:actions, R:rewards, LR:cur_lr}
        repr_td_map = {train_model.wX: wobs, train_model.istraining: True, A: actions, R: rewards, LR: cur_lr}
        rl_td_map = {train_model.wX : wobs, train_model.istraining: True, A:actions, ADV:advs, R:rewards, LR:cur_lr}
        if states is not None:
            rl_td_map[train_model.wS] = states
            rl_td_map[train_model.wM] = masks

        # Defaults for values only some branches produce. The SVIB branch
        # previously left anchor_loss and sv_loss_ unassigned, so the return
        # statement raised for the default `algo`.
        repr_grad_norm = 0.
        represent_loss = 0.
        anchor_loss = 0.
        sv_loss_ = 0.

        if algo == 'use_svib_uniform' or algo == 'use_svib_gaussian':
            repr_td_map[train_model.noise_expand], repr_td_map[train_model.NOISE_KEEP] = sess.run(train_model.noise_expand), noises
            wh_expands, sv_gradients = sess.run([train_model.wh_expand, sv_grads], feed_dict=repr_td_map)
            # Feed the evaluated wh_expand back in for the RL update.
            rl_td_map[train_model.wh_expand] = wh_expands
            tloss, value_loss, policy_loss, policy_entropy, rl_grad_norm, _ = sess.run(
                [loss_expand_, vf_loss_expand_, pg_loss_expand_, entropy_expand_, global_norm_expand, _train_expand],
                feed_dict=rl_td_map
            )
            repr_td_map[SV_GRADS] = sv_gradients
            repr_grad_norm, represent_loss, __ = sess.run([repr_global_norm, repr_loss, _repr_train], feed_dict=repr_td_map)
        elif algo == 'anchor':
            rl_td_map[train_model.wX] = wobs
            tloss, value_loss, policy_loss, policy_entropy, rl_grad_norm, anchor_loss, _ = sess.run(
                [loss, vf_loss, pg_loss, entropy, global_norm, loss_anchor ,_train],
                feed_dict=rl_td_map
            )
        elif algo == 'sv_a2c':
            sv_td_map[train_model.noise_expand], sv_td_map[train_model.NOISE_KEEP] = sess.run(
                train_model.noise_expand), noises
            wvf_expands, sv_gradients = sess.run([train_model.wvf_expand, sv_grads], feed_dict=sv_td_map)
            rl_td_map[train_model.wvf_expand] = wvf_expands
            rl_td_map[train_model.noise_expand], rl_td_map[train_model.NOISE_KEEP] = sess.run(
                train_model.noise_expand), noises
            rl_td_map[SV_GRADS] = sv_gradients
            tloss, value_loss, policy_loss, policy_entropy, rl_grad_norm, sv_loss_, _ = sess.run(
                [loss_expand_, vf_loss_expand_, pg_loss_expand_, entropy_expand_, global_norm_expand, sv_loss, _train_expand],
                feed_dict=rl_td_map
            )
        else:
            # original note: the noise is not used when algo is 'regular'
            rl_td_map[train_model.wX], rl_td_map[train_model.noise] = wobs, noises
            tloss, value_loss, policy_loss, policy_entropy, rl_grad_norm, _ = sess.run(
                [loss, vf_loss, pg_loss, entropy, global_norm, _train],
                feed_dict=rl_td_map
            )
        return tloss, value_loss, policy_loss, policy_entropy, rl_grad_norm, repr_grad_norm, represent_loss, anchor_loss, sv_loss_

    def train_mine(wobs, whs, steps=256, lr=7e-4):
        """Reset ``model/T/update_params`` and run ``steps`` updates of the
        mutual-information estimator on shuffled ``(wobs, whs)`` pairs, then
        record the last loss and T value with ``logger``."""
        idx = np.arange(len(whs))
        sess.run(reset_update_params)
        for i in range(int(steps)):
            np.random.shuffle(idx)
            mi, T_value, __ = sess.run(
                [ib_loss, T, _t_train],
                feed_dict={train_model.wX: wobs[idx], train_model.wh: whs[idx], LR: lr, train_model.istraining: True})
        logger.record_tabular('mutual_info_loss', float(mi))
        logger.record_tabular('T_value', float(T_value))
        logger.dump_tabular()

    def save(save_path):
        """Dump the current values of the trainable variables with joblib."""
        ps = sess.run(params)
        make_path(osp.dirname(save_path))
        joblib.dump(ps, save_path)

    def load(load_path):
        """Assign values stored by ``save`` back to the trainable variables."""
        loaded_params = joblib.load(load_path)
        restores = [p.assign(loaded_p) for p, loaded_p in zip(params, loaded_params)]
        sess.run(restores)

    self.train = train
    self.train_mine = train_mine
    self.train_model = train_model
    self.step_model = step_model
    self.get_wh = step_model.get_wh
    self.get_noise = step_model.get_noise
    self.value = step_model.wvalue
    self.step = step_model.step
    self.initial_state = step_model.w_initial_state
    self.save = save
    self.load = load
    self.sv_M = sv_M
    tf.global_variables_initializer().run(session=sess)
def train(self, model_path, flag='train'):
    """Run the adversarial training loop for ``self.hps.iters`` iterations.

    Every iteration performs, in this order:

    1. Once ``iteration >= hps.enc_pretrain_iters``: latent-discriminator
       updates -- ``hps.dis_pretrain_iters`` of them on the first such
       iteration, ``hps.n_latent_steps`` on each later one.
    2. Once ``iteration >= hps.patch_start_iter``: ``hps.n_patch_steps``
       patch-discriminator updates.
    3. One encoder/decoder update on ``loss_rec + current_alpha * loss_adv``,
       where ``loss_rec`` is the mean absolute reconstruction error.
    4. If ``hps.n_patch_steps > 0`` and the patch stage has started: one
       decoder-only update through ``self.decoder_opt``.
    5. A checkpoint when ``iteration % 1000 == 0`` and on the last iteration.

    Progress is printed to stdout and mirrored to
    ``self.logger.scalar_summary``.

    Args:
        model_path: Forwarded unchanged to ``self.save_model``.
        flag: Prefix of every tag sent to the logger (``'<flag>/<name>'``).
    """
    hps = self.hps

    def report(template, leading, info, global_step):
        # Print one progress line, then mirror the same scalars to the logger.
        print(template % (leading + tuple(info.values())))
        for tag, value in info.items():
            self.logger.scalar_summary(tag, value, global_step)

    for iteration in range(hps.iters):
        # Weight of the adversarial term in the encoder/decoder loss: a linear
        # ramp between enc_pretrain_iters and lat_sched_iters, 0 otherwise.
        # NOTE(review): once iteration + 1 reaches lat_sched_iters the weight
        # falls back to 0 instead of staying at alpha_enc -- confirm intended.
        if iteration + 1 < hps.lat_sched_iters and iteration >= hps.enc_pretrain_iters:
            current_alpha = hps.alpha_enc * (
                iteration + 1 - hps.enc_pretrain_iters) / (
                hps.lat_sched_iters - hps.enc_pretrain_iters)
        else:
            current_alpha = 0

        if iteration >= hps.enc_pretrain_iters:
            # The first iteration after encoder pre-training runs the longer
            # discriminator warm-up; later ones use the regular step count.
            if iteration > hps.enc_pretrain_iters:
                n_latent_steps = hps.n_latent_steps
            else:
                n_latent_steps = hps.dis_pretrain_iters
            for step in range(n_latent_steps):
                #===================== Train latent discriminator =====================#
                data = next(self.data_loader)
                (c_i, c_j), (x_i_t, x_i_tk, x_i_prime, x_j) = self.permute_data(data)
                # encode
                enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
                    x_i_t, x_i_tk, x_i_prime, x_j)
                # latent discriminate
                latent_w_dis, latent_gp = self.latent_discriminate_step(
                    enc_i_t, enc_i_tk, enc_i_prime, enc_j)
                lat_loss = -hps.alpha_dis * latent_w_dis + hps.lambda_ * latent_gp
                reset_grad([self.LatentDiscriminator])
                lat_loss.backward()
                grad_clip([self.LatentDiscriminator], hps.max_grad_norm)
                self.lat_opt.step()
                # .item() instead of .data[0]: indexing a 0-dim tensor is
                # rejected by PyTorch >= 0.5, while .item() accepts any
                # one-element tensor.
                report(
                    'lat_D-%d:[%06d/%06d], w_dis=%.3f, gp=%.2f',
                    (step, iteration + 1, hps.iters),
                    {
                        f'{flag}/D_latent_w_dis': latent_w_dis.item(),
                        f'{flag}/latent_gp': latent_gp.item(),
                    },
                    iteration,
                )

        # two stage training
        # NOTE(review): placed at loop level (not nested under the
        # enc_pretrain_iters guard) -- confirm against the intended schedule.
        if iteration >= hps.patch_start_iter:
            for step in range(hps.n_patch_steps):
                #===================== Train patch discriminator =====================#
                data = next(self.data_loader)
                (c_i, _), (x_i_t, _, _, _) = self.permute_data(data)
                # encode
                enc_i_t, = self.encode_step(x_i_t)
                # NOTE(review): c_sample is drawn but the decoder is fed c_i
                # here; the call is kept in case sample_c advances RNG state.
                c_sample = self.sample_c(x_i_t.size(0))
                x_tilde = self.decode_step(enc_i_t, c_i)
                # Aux classify loss
                patch_w_dis, real_logits, fake_logits, patch_gp = \
                    self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=True)
                # NOTE(review): c_loss, real_acc and fake_acc are not assigned
                # anywhere in this method; unless they exist as globals this
                # raises NameError. Presumably they should be derived from
                # real_logits / fake_logits -- TODO confirm and implement.
                patch_loss = -hps.beta_dis * patch_w_dis + hps.lambda_ * patch_gp + hps.beta_clf * c_loss
                reset_grad([self.PatchDiscriminator])
                patch_loss.backward()
                grad_clip([self.PatchDiscriminator], hps.max_grad_norm)
                self.patch_opt.step()
                report(
                    'patch_D-%d:[%06d/%06d], w_dis=%.3f, gp=%.2f, c_loss=%.3f, real_acc=%.2f, fake_acc=%.2f',
                    (step, iteration + 1, hps.iters),
                    {
                        f'{flag}/D_patch_w_dis': patch_w_dis.item(),
                        f'{flag}/patch_gp': patch_gp.item(),
                        f'{flag}/c_loss': c_loss.item(),
                        f'{flag}/real_acc': real_acc,
                        f'{flag}/fake_acc': fake_acc,
                    },
                    iteration,
                )

        #===================== Train G =====================#
        data = next(self.data_loader)
        (c_i, c_j), (x_i_t, x_i_tk, x_i_prime, x_j) = self.permute_data(data)
        # encode
        enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
            x_i_t, x_i_tk, x_i_prime, x_j)
        # decode
        x_tilde = self.decode_step(enc_i_t, c_i)
        loss_rec = torch.mean(torch.abs(x_tilde - x_i_t))
        # latent discriminate
        loss_adv = self.latent_discriminate_step(
            enc_i_t, enc_i_tk, enc_i_prime, enc_j, is_dis=False)
        ae_loss = loss_rec + current_alpha * loss_adv
        reset_grad([self.Encoder, self.Decoder])
        # The graph is retained whenever patch steps are configured;
        # presumably because the decoder-only update below reuses enc_i_t
        # from this same batch.
        ae_loss.backward(retain_graph=hps.n_patch_steps > 0)
        grad_clip([self.Encoder, self.Decoder], hps.max_grad_norm)
        self.ae_opt.step()
        report(
            'G:[%06d/%06d], loss_rec=%.2f, loss_adv=%.2f, alpha=%.2e',
            (iteration + 1, hps.iters),
            {
                f'{flag}/loss_rec': loss_rec.item(),
                f'{flag}/loss_adv': loss_adv.item(),
                f'{flag}/alpha': current_alpha,
            },
            iteration + 1,
        )

        # patch discriminate
        if hps.n_patch_steps > 0 and iteration >= hps.patch_start_iter:
            # Re-decode the batch used for the G update above, this time
            # conditioned on a freshly sampled c.
            c_sample = self.sample_c(x_i_t.size(0))
            x_tilde = self.decode_step(enc_i_t, c_sample)
            patch_w_dis, real_logits, fake_logits = \
                self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=False)
            # NOTE(review): same unresolved c_loss / real_acc / fake_acc as in
            # the patch-discriminator step -- TODO confirm.
            patch_loss = hps.beta_dec * patch_w_dis + hps.beta_clf * c_loss
            reset_grad([self.Decoder])
            patch_loss.backward()
            grad_clip([self.Decoder], hps.max_grad_norm)
            self.decoder_opt.step()
            report(
                'G:[%06d/%06d]: patch_w_dis=%.2f, c_loss=%.2f, real_acc=%.2f, fake_acc=%.2f',
                (iteration + 1, hps.iters),
                {
                    f'{flag}/G_patch_w_dis': patch_w_dis.item(),
                    f'{flag}/c_loss': c_loss.item(),
                    f'{flag}/real_acc': real_acc,
                    f'{flag}/fake_acc': fake_acc,
                },
                iteration + 1,
            )

        if iteration % 1000 == 0 or iteration + 1 == hps.iters:
            self.save_model(model_path, iteration)