Example #1
def finetune_discrete_decoder(trainer, look_up, model_path, flag='train'):
    # trainer is already trained
    hyper_params = trainer.hps

    for iteration in range(hyper_params.enc_pretrain_iters):
        data = next(trainer.data_loader)
        c, x = trainer.permute_data(data)

        encoded = trainer.encode_step(x)
        x = look_up(x)
        x_tilde = trainer.decode_step(encoded, c)
        loss_rec = torch.mean(torch.abs(x_tilde - x))
        reset_grad([trainer.Encoder, trainer.Decoder])
        loss_rec.backward()
        grad_clip([trainer.Encoder, trainer.Decoder],
                  trainer.hps.max_grad_norm)
        trainer.ae_opt.step()

        # tb info
        info = {
            f'{flag}/disc_loss_rec': loss_rec.item(),
        }
        slot_value = (iteration + 1, hyper_params.enc_pretrain_iters) + \
            tuple([value for value in info.values()])
        log = 'train_discrete:[%06d/%06d], loss_rec=%.3f'
        print(log % slot_value, end='\r')

        if iteration % 100 == 0:
            for tag, value in info.items():
                trainer.logger.scalar_summary(tag, value, iteration + 1)
        if (iteration + 1) % 1000 == 0:
            trainer.save_model(model_path, 'dc', iteration + 1)
    print()
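Note: every example below calls reset_grad and grad_clip without defining them. A minimal PyTorch sketch of what such helpers typically do, assuming each list element is an nn.Module (the names and signatures mirror the calls above; everything else is an assumption):

import torch

def reset_grad(net_list):
    # Zero accumulated gradients of every module before the next backward pass.
    for net in net_list:
        net.zero_grad()

def grad_clip(net_list, max_grad_norm):
    # Clip each module's gradient norm in place before the optimizer step.
    for net in net_list:
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm)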
Example #2
 def train(self, model_path, flag='train'):
     # load hyperparams
     hps = self.hps
     for iteration in range(hps.iters):
         data = next(self.data_loader)
         y, x = self.permute_data(data)
         # encode
         enc = self.encode_step(x)
         # forward to classifier
         logits = self.forward_step(enc)
         # calculate loss
         loss = self.cal_loss(logits, y)
         # optimize
         reset_grad([self.SpeakerClassifier])
         loss.backward()
         grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
         self.opt.step()
         # calculate acc
         acc = cal_acc(logits, y)
         # print info
         info = {
             f'{flag}/loss': loss.item(),
             f'{flag}/acc': acc,
         }
         slot_value = (iteration + 1, hps.iters) + tuple([value for value in info.values()])
         log = 'iter:[%06d/%06d], loss=%.3f, acc=%.3f'
         print(log % slot_value, end='\r')
         for tag, value in info.items():
             self.logger.scalar_summary(tag, value, iteration)
         if iteration % 1000 == 0 or iteration + 1 == hps.iters:
             valid_loss, valid_acc = self.valid(n_batches=10)
             # print info
             info = {
                 f'{flag}/valid_loss': valid_loss, 
                 f'{flag}/valid_acc': valid_acc,
             }
             slot_value = (iteration + 1, hps.iters) + \
                     tuple([value for value in info.values()])
             log = 'iter:[%06d/%06d], valid_loss=%.3f, valid_acc=%.3f'
             print(log % slot_value)
             for tag, value in info.items():
                 self.logger.scalar_summary(tag, value, iteration)
             self.save_model(model_path, iteration)
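The helpers cal_loss and cal_acc are not shown either. A plausible sketch, assuming logits of shape [batch, n_classes] and integer class labels (the actual implementations may differ, e.g. Example #9 passes an extra shift argument):

import torch
import torch.nn.functional as F

def cal_loss(logits, y_true):
    # Likely a plain cross-entropy between classifier logits and labels.
    return F.cross_entropy(logits, y_true)

def cal_acc(logits, y_true):
    # Fraction of samples whose argmax prediction matches the label.
    pred = torch.argmax(logits, dim=1)
    return (pred == y_true).float().mean().item()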
Example #3
 def train(self, model_path, flag='train', mode='train'):
     # load hyperparams
     hps = self.hps
     if mode == 'pretrain_G':
         for iteration in range(hps.enc_pretrain_iters):
             data = next(self.data_loader)
             c, x = self.permute_data(data)
             # encode
             enc = self.encode_step(x)
             x_tilde = self.decode_step(enc, c)
             loss_rec = torch.mean(torch.abs(x_tilde - x))
             reset_grad([self.Encoder, self.Decoder])
             loss_rec.backward()
             grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
             self.ae_opt.step()
             # tb info
             info = {
                 f'{flag}/pre_loss_rec': loss_rec.item(),
             }
             slot_value = (iteration + 1, hps.enc_pretrain_iters) + tuple([value for value in info.values()])
             log = 'pre_G:[%06d/%06d], loss_rec=%.3f'
             print(log % slot_value)
             if iteration % 100 == 0:
                 for tag, value in info.items():
                     self.logger.scalar_summary(tag, value, iteration + 1)
     elif mode == 'pretrain_D':
         for iteration in range(hps.dis_pretrain_iters):
             data = next(self.data_loader)
             c, x = self.permute_data(data)
             # encode
             enc = self.encode_step(x)
             # classify speaker
             logits = self.clf_step(enc)
             loss_clf = self.cal_loss(logits, c)
             # update 
             reset_grad([self.SpeakerClassifier])
             loss_clf.backward()
             grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
             self.clf_opt.step()
             # calculate acc
             acc = cal_acc(logits, c)
             info = {
                 f'{flag}/pre_loss_clf': loss_clf.item(),
                 f'{flag}/pre_acc': acc,
             }
             slot_value = (iteration + 1, hps.dis_pretrain_iters) + tuple([value for value in info.values()])
             log = 'pre_D:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
             print(log % slot_value)
             if iteration % 100 == 0:
                 for tag, value in info.items():
                     self.logger.scalar_summary(tag, value, iteration + 1)
     elif mode == 'patchGAN':
         for iteration in range(hps.patch_iters):
             #=======train D=========#
             for step in range(hps.n_patch_steps):
                 data = next(self.data_loader)
                 c, x = self.permute_data(data)
                 ## encode
                 enc = self.encode_step(x)
                 # sample c
                 c_prime = self.sample_c(x.size(0))
                 # generator
                 x_tilde = self.gen_step(enc, c_prime)
                 # discriminator
                 w_dis, real_logits, gp = self.patch_step(x, x_tilde, is_dis=True)
                 # aux classification loss 
                 loss_clf = self.cal_loss(real_logits, c)
                 loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                 reset_grad([self.PatchDiscriminator])
                 loss.backward()
                 grad_clip([self.PatchDiscriminator], self.hps.max_grad_norm)
                 self.patch_opt.step()
                 # calculate acc
                 acc = cal_acc(real_logits, c)
                 info = {
                     f'{flag}/w_dis': w_dis.item(),
                     f'{flag}/gp': gp.item(), 
                     f'{flag}/real_loss_clf': loss_clf.item(),
                     f'{flag}/real_acc': acc, 
                 }
                 slot_value = (step, iteration+1, hps.patch_iters) + tuple([value for value in info.values()])
                 log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f'
                 print(log % slot_value)
                 if iteration % 100 == 0:
                     for tag, value in info.items():
                         self.logger.scalar_summary(tag, value, iteration + 1)
             #=======train G=========#
             data = next(self.data_loader)
             c, x = self.permute_data(data)
             # encode
             enc = self.encode_step(x)
             # sample c
             c_prime = self.sample_c(x.size(0))
             # generator
             x_tilde = self.gen_step(enc, c_prime)
             # discriminator
             loss_adv, fake_logits = self.patch_step(x, x_tilde, is_dis=False)
             # aux classification loss 
             loss_clf = self.cal_loss(fake_logits, c_prime)
             loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
             reset_grad([self.Generator])
             loss.backward()
             grad_clip([self.Generator], self.hps.max_grad_norm)
             self.gen_opt.step()
             # calculate acc
             acc = cal_acc(fake_logits, c_prime)
             info = {
                 f'{flag}/loss_adv': loss_adv.item(),
                 f'{flag}/fake_loss_clf': loss_clf.item(),
                 f'{flag}/fake_acc': acc, 
             }
             slot_value = (iteration+1, hps.patch_iters) + tuple([value for value in info.values()])
             log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f'
             print(log % slot_value)
             if iteration % 100 == 0:
                 for tag, value in info.items():
                     self.logger.scalar_summary(tag, value, iteration + 1)
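The patchGAN branch combines a Wasserstein critic score w_dis, an auxiliary classification loss and a gradient penalty gp, all returned by patch_step. patch_step itself is not shown; a minimal sketch of how a WGAN-GP style gradient penalty is commonly computed (netD here is a stand-in for the patch discriminator, not the repository's actual code):

import torch

def gradient_penalty(netD, x_real, x_fake):
    # Interpolate between real and generated samples and penalize critic
    # gradients whose norm deviates from 1 (WGAN-GP).
    eps = torch.rand(x_real.size(0), *([1] * (x_real.dim() - 1)),
                     device=x_real.device)
    x_hat = (eps * x_real + (1 - eps) * x_fake).requires_grad_(True)
    d_hat = netD(x_hat)
    grads = torch.autograd.grad(outputs=d_hat.sum(), inputs=x_hat,
                                create_graph=True)[0]
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()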
Example #4
 else:
     current_alpha = hps.alpha_enc
 #==================train D==================#
 for step in range(hps.n_latent_steps):
     data = next(self.data_loader)
     c, x = self.permute_data(data)
     # encode
     enc = self.encode_step(x)
     # classify speaker
     logits = self.clf_step(enc)
     loss_clf = self.cal_loss(logits, c)
     loss = hps.alpha_dis * loss_clf
     # update 
     reset_grad([self.SpeakerClassifier])
     loss.backward()
     grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
     self.clf_opt.step()
     # calculate acc
     acc = cal_acc(logits, c)
     info = {
         f'{flag}/D_loss_clf': loss_clf.item(),
         f'{flag}/D_acc': acc,
     }
     slot_value = (step, iteration + 1, hps.iters) + tuple([value for value in info.values()])
     log = 'D-%d:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
     print(log % slot_value)
     if iteration % 100 == 0:
         for tag, value in info.items():
             self.logger.scalar_summary(tag, value, iteration + 1)
 #==================train G==================#
 data = next(self.data_loader)
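Example #4 is a fragment from the middle of the adversarial training loop; the current_alpha it falls back to is ramped linearly toward hps.alpha_enc, as the fuller examples below show. That schedule, pulled out as a standalone helper (hypothetical name, same arithmetic as in Examples #5 and #9):

def alpha_schedule(iteration, alpha_enc, lat_sched_iters):
    # Linearly ramp the adversarial weight from 0 to alpha_enc, then hold it.
    if iteration < lat_sched_iters:
        return alpha_enc * (iteration / lat_sched_iters)
    return alpha_enc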
Example #5
    def train(self, model_path, flag='train', mode='train'):
        if not os.path.isdir(model_path):
            os.makedirs(model_path)
            os.chmod(model_path, 0o755)
        model_path = os.path.join(model_path, 'model.pkl')

        # load hyperparams
        hps = self.hps
        if mode == 'pretrain_G':
            for iteration in range(2200):
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                # encode
                enc = self.encode_step(x)
                x_tilde = self.decode_step(enc, c)
                loss_rec = torch.mean(torch.abs(x_tilde - x))
                reset_grad([self.Encoder, self.Decoder])
                loss_rec.backward()
                grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
                self.ae_opt.step()
                # tb info
                info = {
                    f'{flag}/pre_loss_rec': loss_rec.item(),
                }
                slot_value = (iteration + 1, 2200) + tuple(
                    [value for value in info.values()])
                log = 'pre_G:[%06d/%06d], loss_rec=%.3f'
                print(log % slot_value)
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
        elif mode == 'pretrain_D':
            for iteration in range(2200):
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                # encode
                enc = self.encode_step(x)
                # classify speaker
                logits = self.clf_step(enc)
                loss_clf = self.cal_loss(logits, c)
                # update
                reset_grad([self.SpeakerClassifier])
                loss_clf.backward()
                grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
                self.clf_opt.step()
                # calculate acc
                acc = cal_acc(logits, c)
                info = {
                    f'{flag}/pre_loss_clf': loss_clf.item(),
                    f'{flag}/pre_acc': acc,
                }
                slot_value = (iteration + 1, 2200) + tuple(
                    [value for value in info.values()])
                log = 'pre_D:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
                print(log % slot_value)
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
        elif mode == 'patchGAN':
            for iteration in range(1100):
                #=======train D=========#
                for step in range(hps.n_patch_steps):
                    data = next(self.data_loader)
                    c, x = self.permute_data(data)
                    ## encode
                    enc = self.encode_step(x)
                    # sample c
                    c_prime = self.sample_c(x.size(0))
                    # generator
                    x_tilde = self.gen_step(enc, c_prime)
                    # discriminator
                    w_dis, real_logits, gp = self.patch_step(x,
                                                             x_tilde,
                                                             is_dis=True)
                    # aux classification loss
                    loss_clf = self.cal_loss(real_logits, c)
                    loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                    reset_grad([self.PatchDiscriminator])
                    loss.backward()
                    grad_clip([self.PatchDiscriminator],
                              self.hps.max_grad_norm)
                    self.patch_opt.step()
                    # calculate acc
                    acc = cal_acc(real_logits, c)
                    info = {
                        f'{flag}/w_dis': w_dis.item(),
                        f'{flag}/gp': gp.item(),
                        f'{flag}/real_loss_clf': loss_clf.item(),
                        f'{flag}/real_acc': acc,
                    }
                    slot_value = (step, iteration + 1, 1100) + tuple(
                        [value for value in info.values()])
                    log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f'
                    print(log % slot_value)
                    if iteration % 100 == 0:
                        for tag, value in info.items():
                            self.logger.scalar_summary(tag, value,
                                                       iteration + 1)
                #=======train G=========#
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                # encode
                enc = self.encode_step(x)
                # sample c
                c_prime = self.sample_c(x.size(0))
                # generator
                x_tilde = self.gen_step(enc, c_prime)
                # discriminator
                loss_adv, fake_logits = self.patch_step(x,
                                                        x_tilde,
                                                        is_dis=False)
                # aux classification loss
                loss_clf = self.cal_loss(fake_logits, c_prime)
                loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
                reset_grad([self.Generator])
                loss.backward()
                grad_clip([self.Generator], self.hps.max_grad_norm)
                self.gen_opt.step()
                # calculate acc
                acc = cal_acc(fake_logits, c_prime)
                info = {
                    f'{flag}/loss_adv': loss_adv.item(),
                    f'{flag}/fake_loss_clf': loss_clf.item(),
                    f'{flag}/fake_acc': acc,
                }
                slot_value = (iteration + 1, 1100) + tuple(
                    [value for value in info.values()])
                log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f'
                print(log % slot_value)
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if iteration % 1000 == 0 or iteration + 1 == hps.patch_iters:
                    self.save_model(model_path, iteration + hps.iters)
        elif mode == 'train':
            for iteration in range(1100):
                # calculate current alpha
                if iteration < hps.lat_sched_iters:
                    current_alpha = hps.alpha_enc * (iteration /
                                                     hps.lat_sched_iters)
                else:
                    current_alpha = hps.alpha_enc
                #==================train D==================#
                for step in range(hps.n_latent_steps):
                    data = next(self.data_loader)
                    c, x = self.permute_data(data)
                    # encode
                    enc = self.encode_step(x)
                    # classify speaker
                    logits = self.clf_step(enc)
                    loss_clf = self.cal_loss(logits, c)
                    loss = hps.alpha_dis * loss_clf
                    # update
                    reset_grad([self.SpeakerClassifier])
                    loss.backward()
                    grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
                    self.clf_opt.step()
                    # calculate acc
                    acc = cal_acc(logits, c)
                    info = {
                        f'{flag}/D_loss_clf': loss_clf.item(),
                        f'{flag}/D_acc': acc,
                    }
                    slot_value = (step, iteration + 1, 1100) + tuple(
                        [value for value in info.values()])
                    log = 'D-%d:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
                    print(log % slot_value)
                    if iteration % 100 == 0:
                        for tag, value in info.items():
                            self.logger.scalar_summary(tag, value,
                                                       iteration + 1)
                #==================train G==================#
                data = next(self.data_loader)
                c, x = self.permute_data(data)
                # encode
                enc = self.encode_step(x)
                # decode
                x_tilde = self.decode_step(enc, c)
                loss_rec = torch.mean(torch.abs(x_tilde - x))
                # classify speaker
                logits = self.clf_step(enc)
                acc = cal_acc(logits, c)
                loss_clf = self.cal_loss(logits, c)
                # maximize classification loss
                loss = loss_rec - current_alpha * loss_clf
                reset_grad([self.Encoder, self.Decoder])
                loss.backward()
                grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
                self.ae_opt.step()
                info = {
                    f'{flag}/loss_rec': loss_rec.item(),
                    f'{flag}/G_loss_clf': loss_clf.item(),
                    f'{flag}/alpha': current_alpha,
                    f'{flag}/G_acc': acc,
                }
                slot_value = (iteration + 1, 1100) + tuple(
                    [value for value in info.values()])
                log = 'G:[%06d/%06d], loss_rec=%.3f, loss_clf=%.2f, alpha=%.2e, acc=%.2f'
                print(log % slot_value)
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if iteration % 1000 == 0 or iteration + 1 == hps.iters:
                    self.save_model(model_path, iteration)
Example #6
    def __init__(self,
                 *,
                 policy,
                 ob_space,
                 ac_space,
                 nbatch_act,
                 nbatch_train,
                 nsteps,
                 ent_coef,
                 vf_coef,
                 max_grad_norm,
                 cell=256,
                 sv_M=32,
                 algo='regular',
                 ib_alpha=1e-3):
        sess = tf_util.make_session()

        act_model = policy(sess,
                           ob_space,
                           ac_space,
                           nbatch_act,
                           1,
                           1,
                           cell=cell,
                           M=sv_M,
                           model='step_model',
                           algo=algo)
        train_model = policy(sess,
                             ob_space,
                             ac_space,
                             nbatch_train,
                             1,
                             nsteps,
                             cell=cell,
                             M=sv_M,
                             model='train_model',
                             algo=algo)

        A = train_model.wpdtype.sample_placeholder([None])
        ADV = tf.placeholder(tf.float32, [None])
        R = tf.placeholder(tf.float32, [None])
        OLDNEGLOGPAC = tf.placeholder(tf.float32, [None])
        OLDNEGLOGPAC_expand = tf.placeholder(tf.float32, [None, sv_M])
        OLDVPRED = tf.placeholder(tf.float32, [None])
        OLDVPRED_expand = tf.placeholder(tf.float32, [None, sv_M])
        LR = tf.placeholder(tf.float32, [])
        CLIPRANGE = tf.placeholder(tf.float32, [])

        if algo == 'use_svib_uniform' or algo == 'use_svib_gaussian':

            def expand_placeholder(X, M=sv_M):
                return tf.tile(tf.expand_dims(X, axis=-1), [1, M])

            A_expand, R_expand = expand_placeholder(A), expand_placeholder(R)
            neglogpac_expand = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=train_model.wpi_expand,
                labels=A_expand)  #shape=[nbatch, sv_M]
            entropy_expand = tf.reduce_mean(cat_entropy(
                train_model.wpi_expand),
                                            axis=-1)  #shape=[nbatch]
            vpred_expand = train_model.wvf_expand[:, :, 0]
            vpredclipped_expand = OLDVPRED_expand + tf.clip_by_value(
                train_model.wvf_expand[:, :, 0] - OLDVPRED_expand, -CLIPRANGE,
                CLIPRANGE)
            vf_loss1_expand = tf.square(vpred_expand - R_expand)
            vf_loss2_expand = tf.square(vpredclipped_expand - R_expand)
            vf_loss_expand = .5 * tf.reduce_mean(tf.maximum(
                vf_loss1_expand, vf_loss2_expand),
                                                 axis=-1)  #shape = [nbatch]
            ratio_expand = tf.exp(OLDNEGLOGPAC_expand - neglogpac_expand)
            ADV_expand = R_expand - OLDVPRED_expand
            # ADV_expand_mean, ADV_expand_var = tf.nn.moments(ADV_expand, axes=0, keep_dims=True)#shape = [1,sv_M]
            ADV_expand_mean, ADV_expand_var = tf.nn.moments(
                ADV_expand, axes=[0, 1])  #shape = [1,sv_M]
            ADV_expand_normal = (ADV_expand - ADV_expand_mean) / (
                tf.sqrt(ADV_expand_var) + 1e-8)
            pg_losses_expand = -ADV_expand_normal * ratio_expand
            pg_losses2_expand = -ADV_expand_normal * tf.clip_by_value(
                ratio_expand, 1. - CLIPRANGE, 1. + CLIPRANGE)
            pg_loss_expand = tf.reduce_mean(tf.maximum(pg_losses_expand,
                                                       pg_losses2_expand),
                                            axis=-1)
            J_theta = -(pg_loss_expand + vf_coef * vf_loss_expand -
                        ent_coef * entropy_expand)

            loss_expand = -J_theta / float(nbatch_train)
            pg_loss_expand_ = tf.reduce_mean(pg_loss_expand)
            vf_loss_expand_ = tf.reduce_mean(vf_loss_expand)
            entropy_expand_ = tf.reduce_mean(entropy_expand)

            log_p_grads = tf.gradients(
                J_theta / np.sqrt(ib_alpha),
                [train_model.wh_expand])[0]  #shape=[nbatch, sv_M, cell]
            if algo == 'use_svib_gaussian':
                mean, var = tf.nn.moments(
                    train_model.wh_expand, axes=1,
                    keep_dims=True)  #shape=[nbatch, 1,cell]
                gaussian_grad = -(train_model.wh_expand - mean) / (
                    float(sv_M) * (var + 1e-3))
                log_p_grads += 5e-3 * (
                    tf_l2norm(log_p_grads, axis=-1, keep_dims=True) /
                    tf_l2norm(gaussian_grad, axis=-1,
                              keep_dims=True)) * gaussian_grad
            sv_grads = tf.constant(0.,
                                   tf.float32,
                                   shape=[nbatch_train, 0, cell])
            exploit_total_norm_square = 0
            explore_total_norm_square = 0
            explore_coef = 1.
            if env_name == 'SeaquestNoFrameskip-v4':
                explore_coef = 0.01
            elif env_name in [
                    'AirRaidNoFrameskip-v4',
                    'BreakoutNoFrameskip-v4', 'AtlantisNoFrameskip-v4',
                    'StarGunnerNoFrameskip-v4', 'AsteroidsNoFrameskip-v4',
                    'YarsRevengeNoFrameskip-v4'
            ]:
                explore_coef = 0.
            print('env_name:', env_name, 'explore_coef: ', explore_coef)
            for i in range(sv_M):
                exploit = tf.reduce_sum(train_model.rpf_matrix[:, :, i:i + 1] *
                                        log_p_grads,
                                        axis=1)
                explore = np.sqrt(
                    ib_alpha) * explore_coef * train_model.rpf_grads[:, i, :]
                exploit_total_norm_square += tf.square(
                    tf_l2norm(exploit, axis=-1, keep_dims=False))
                explore_total_norm_square += tf.square(
                    tf_l2norm(explore, axis=-1, keep_dims=False))
                sv_grad = exploit + explore  #shape=[nbatch, cell]
                sv_grads = tf.concat(
                    [sv_grads, tf.expand_dims(sv_grad, axis=1)], axis=1)
            SV_GRADS = tf.placeholder(tf.float32, [nbatch_train, sv_M, cell])
            repr_loss = tf.reduce_mean(SV_GRADS * train_model.wh_expand,
                                       axis=1)  #shape=[nbatch,cell]
            repr_loss = -tf.reduce_mean(tf.reduce_sum(
                repr_loss,
                axis=-1))  #max optimization problem to minimization problem

            #op for debugging and visualization
            exploit_explore_ratio = tf.sqrt(
                exploit_total_norm_square /
                tf.maximum(explore_total_norm_square, 0.01))[0]
            # rpf_mat = tf.expand_dims(train_model.rpf_matrix, axis=-1)
            # log_p_grads_tile = tf.tile(tf.expand_dims(log_p_grads, axis=2), [1,1,sv_M,1])
            # exploit = tf.reduce_sum(rpf_mat*log_p_grads_tile, axis=1)
            # explore = np.sqrt(ib_alpha) * train_model.rpf_grads
            # sv_grads = exploit + explore
            # ind = 1
            # exploit = tf.reduce_sum(train_model.rpf_matrix[:, :, i:i + 1] * log_p_grads, axis=1)
            # explore = train_model.rpf_grads[:, i, :]
            # clip_coef = tf_l2norm(exploit, axis=-1, keep_dims=True)
            # explore_norm = tf_l2norm(explore, axis=-1, keep_dims=True)
            # explore = explore * 1e-2 * clip_coef / tf.maximum(explore_norm, clip_coef)
            # sv_grad = exploit + np.sqrt(ib_alpha) * explore  # shape=[nbatch, cell]

            grads_expand, grad_norm_expand = grad_clip(loss_expand,
                                                       max_grad_norm,
                                                       ['model/worker_module'])
            trainer_expand = tf.train.AdamOptimizer(learning_rate=LR,
                                                    epsilon=1e-5)
            _train_expand = trainer_expand.apply_gradients(grads_expand)
            repr_grads, repr_global_norm = grad_clip(
                repr_loss, max_grad_norm, ['model/ordinary_encoder'])
            repr_trainer = tf.train.AdamOptimizer(learning_rate=LR,
                                                  epsilon=1e-5)
            _repr_train = repr_trainer.apply_gradients(repr_grads)
        else:
            print('env_name:', env_name)
            neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=train_model.wpi, labels=A)
            entropy = tf.reduce_mean(cat_entropy(train_model.wpi))
            vpred = train_model.wvf[:, 0]
            vpredclipped = OLDVPRED + tf.clip_by_value(
                train_model.wvf[:, 0] - OLDVPRED, -CLIPRANGE, CLIPRANGE)
            vf_losses1 = tf.square(vpred - R)
            vf_losses2 = tf.square(vpredclipped - R)
            vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))
            ratio = tf.exp(OLDNEGLOGPAC - neglogpac)
            pg_losses = -ADV * ratio
            pg_losses2 = -ADV * tf.clip_by_value(ratio, 1.0 - CLIPRANGE,
                                                 1.0 + CLIPRANGE)
            pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2))
            loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef

            grads, grad_norm = grad_clip(
                loss, max_grad_norm,
                ['model/worker_module', 'model/ordinary_encoder'])
            trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
            _train = trainer.apply_gradients(grads)

        with tf.variable_scope('model'):
            params = tf.trainable_variables()

        def generate_old_expand_data(obs, noises, masks, actions, states=None):
            noises_expand = sess.run(train_model.noise_expand)
            repr_td_map = {
                train_model.wX: obs,
                train_model.istraining: False,
                A: actions,
                train_model.noise_expand: noises_expand,
                train_model.NOISE_KEEP: noises
            }
            if states is not None:
                repr_td_map[train_model.wS] = states
                repr_td_map[train_model.wM] = masks
            neglogpacs_expand, vpreds_expand = \
                sess.run([neglogpac_expand, vpred_expand], feed_dict=repr_td_map)
            shape = noises_expand.shape
            noises_expand = noises_expand.reshape(nbatch_train, sv_M - 1,
                                                  *shape[1:])
            return [noises_expand, neglogpacs_expand, vpreds_expand]

        def train(lr,
                  cliprange,
                  obs,
                  noises,
                  returns,
                  masks,
                  actions,
                  values,
                  neglogpacs,
                  noises_expand=None,
                  neglogpacs_expand=None,
                  vpreds_expand=None,
                  states=None):
            advs = returns - values
            advs = (advs - advs.mean()) / (advs.std() + 1e-8)
            if algo == 'use_svib_uniform' or algo == 'use_svib_gaussian':
                shape = noises_expand.shape
                noises_expand_ = noises_expand.reshape(
                    nbatch_train * (sv_M - 1), *shape[2:])
                # print(noises_expand_.shape)
                repr_td_map = {
                    train_model.wX: obs,
                    train_model.istraining: True,
                    A: actions,
                    R: returns,
                    LR: lr,
                    CLIPRANGE: cliprange,
                    train_model.noise_expand: noises_expand_,
                    train_model.NOISE_KEEP: noises,
                    OLDNEGLOGPAC_expand: neglogpacs_expand,
                    OLDVPRED_expand: vpreds_expand
                }
            rl_td_map = {
                train_model.istraining: True,
                A: actions,
                R: returns,
                LR: lr,
                CLIPRANGE: cliprange
            }
            if states is not None:
                if algo == 'use_svib_uniform' or algo == 'use_svib_gaussian':
                    repr_td_map[train_model.wS] = states
                    repr_td_map[train_model.wM] = masks
                rl_td_map[train_model.wS] = states
                rl_td_map[train_model.wM] = masks

            if algo == 'use_svib_uniform' or algo == 'use_svib_gaussian':
                sv_gradients, whs_expand, ir_ratio = sess.run(
                    [sv_grads, train_model.wh_expand, exploit_explore_ratio],
                    feed_dict=repr_td_map)
                rl_td_map[OLDNEGLOGPAC_expand], rl_td_map[
                    OLDVPRED_expand], rl_td_map[
                        train_model.
                        wh_expand] = neglogpacs_expand, vpreds_expand, whs_expand
                value_loss, policy_loss, policy_entropy, _, rl_grad_norm = sess.run(
                    [
                        vf_loss_expand_, pg_loss_expand_, entropy_expand_,
                        _train_expand, grad_norm_expand
                    ],
                    feed_dict=rl_td_map)
                repr_td_map[SV_GRADS] = sv_gradients
                repr_grad_norm, represent_loss, __ = sess.run(
                    [repr_global_norm, repr_loss, _repr_train],
                    feed_dict=repr_td_map)
            else:
                rl_td_map[train_model.wX], rl_td_map[
                    train_model.
                    noise] = obs, noises  #noise won't be used when algo is 'regular'
                rl_td_map[OLDNEGLOGPAC], rl_td_map[OLDVPRED], rl_td_map[
                    ADV] = neglogpacs, values, advs
                value_loss, policy_loss, policy_entropy, _, rl_grad_norm = sess.run(
                    [vf_loss, pg_loss, entropy, _train, grad_norm],
                    feed_dict=rl_td_map)
                represent_loss, rpf_norm_, rpf_grad_norm_, sv_gradients, ir_ratio, repr_grad_norm = 0., 0., 0., 0., 0, 0.
            return policy_loss, value_loss, policy_entropy, represent_loss, ir_ratio, rl_grad_norm, repr_grad_norm

        self.loss_names = [
            'policy_loss', 'value_loss', 'policy_entropy', 'represent_loss',
            'exploit_explore_ratio', 'rl_grad_norm', 'repr_grad_norm'
        ]

        def save(save_path):
            ps = sess.run(params)
            make_path(osp.dirname(save_path))
            joblib.dump(ps, save_path)

        def load(load_path):
            loaded_params = joblib.load(load_path)
            restores = []
            for p, loaded_p in zip(params, loaded_params):
                restores.append(p.assign(loaded_p))
            sess.run(restores)
            # If you want to load weights, also save/load observation scaling inside VecNormalize

        self.generate_old_expand_data = generate_old_expand_data
        self.train = train
        self.train_model = train_model
        self.act_model = act_model
        self.step = act_model.step
        self.value = act_model.wvalue
        self.initial_state = act_model.w_initial_state
        self.save = save
        self.load = load
        tf.global_variables_initializer().run(session=sess)  #pylint: disable=E1101
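Both branches of Example #6 call cat_entropy on the policy logits. In OpenAI-baselines-style code this helper usually computes the entropy of the softmax over the last axis in a numerically stable way; a sketch of that standard form (TF1 style, matching the placeholders above; the project's own helper may differ):

import tensorflow as tf

def cat_entropy(logits):
    # Entropy of the categorical distribution defined by the logits,
    # subtracting the max first for numerical stability.
    a0 = logits - tf.reduce_max(logits, axis=-1, keepdims=True)
    ea0 = tf.exp(a0)
    z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
    p0 = ea0 / z0
    return tf.reduce_sum(p0 * (tf.log(z0) - a0), axis=-1)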
Example #7
 def train(self, model_path, flag='train', mode='train'):
     # load hyperparams
     hps = self.hps
     if mode == 'pretrain_G':
         for iteration in range(hps.enc_pretrain_iters):
             data = next(self.data_loader)
             c, x = self.permute_data(data)
             # encode
             enc = self.encode_step(x)
             x_tilde = self.decode_step(enc, c)
             loss_rec = torch.mean(torch.abs(x_tilde - x))
             reset_grad([self.Encoder, self.Decoder])
             loss_rec.backward()
             grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
             self.ae_opt.step()
             # tb info
             info = {
                 f'{flag}/pre_loss_rec': loss_rec.item(),
             }
             slot_value = (iteration + 1, hps.enc_pretrain_iters) + tuple(
                 [value for value in info.values()])
             log = 'pre_G:[%06d/%06d], loss_rec=%.3f'
             print(log % slot_value)
             if iteration % 100 == 0:
                 for tag, value in info.items():
                     self.logger.scalar_summary(tag, value, iteration + 1)
     elif mode == 'pretrain_D':
         for iteration in range(hps.dis_pretrain_iters):
             data = next(self.data_loader)
             c, x = self.permute_data(data)
             # encode
             enc = self.encode_step(x)
             # classify speaker
             logits = self.clf_step(enc)
             loss_clf = self.cal_loss(logits, c)
             # update
             reset_grad([self.SpeakerClassifier])
             loss_clf.backward()
             grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
             self.clf_opt.step()
             # calculate acc
             acc = cal_acc(logits, c)
             info = {
                 f'{flag}/pre_loss_clf': loss_clf.item(),
                 f'{flag}/pre_acc': acc,
             }
             slot_value = (iteration + 1, hps.dis_pretrain_iters) + tuple(
                 [value for value in info.values()])
             log = 'pre_D:[%06d/%06d], loss_clf=%.2f, acc=%.2f'
             print(log % slot_value)
             if iteration % 100 == 0:
                 for tag, value in info.items():
                     self.logger.scalar_summary(tag, value, iteration + 1)
     elif mode == 'patchGAN':
         for iteration in range(hps.patch_iters):
             #=======train D=========#
             for step in range(hps.n_patch_steps):
                 data = next(self.data_loader)
                 c, x = self.permute_data(data)
                 ## encode
                 enc = self.encode_step(x)
                 # sample c
                 c_prime = self.sample_c(x.size(0))
                 # generator
                 x_tilde = self.gen_step(enc, c_prime)
                 # discriminator
                 w_dis, real_logits, gp = self.patch_step(x,
                                                          x_tilde,
                                                          is_dis=True)
                 # aux classification loss
                 loss_clf = self.cal_loss(real_logits, c)
                 loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                 reset_grad([self.PatchDiscriminator])
                 loss.backward()
                 grad_clip([self.PatchDiscriminator],
                           self.hps.max_grad_norm)
                 self.patch_opt.step()
                 # calculate acc
                 acc = cal_acc(real_logits, c)
                 info = {
                     f'{flag}/w_dis': w_dis.item(),
                     f'{flag}/gp': gp.item(),
                     f'{flag}/real_loss_clf': loss_clf.item(),
                     f'{flag}/real_acc': acc,
                 }
                 slot_value = (step, iteration + 1,
                               hps.patch_iters) + tuple(
                                   [value for value in info.values()])
                 log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f'
                 print(log % slot_value)
                 if iteration % 100 == 0:
                     for tag, value in info.items():
                         self.logger.scalar_summary(tag, value,
                                                    iteration + 1)
             #=======train G=========#
             data = next(self.data_loader)
             c, x = self.permute_data(data)
             # encode
             enc = self.encode_step(x)
             # sample c
             c_prime = self.sample_c(x.size(0))
             # generator
             x_tilde = self.gen_step(enc, c_prime)
             # discriminator
             loss_adv, fake_logits = self.patch_step(x,
                                                     x_tilde,
                                                     is_dis=False)
             # aux classification loss
             loss_clf = self.cal_loss(fake_logits, c_prime)
             loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
             reset_grad([self.Generator])
             loss.backward()
             grad_clip([self.Generator], self.hps.max_grad_norm)
             self.gen_opt.step()
             # calculate acc
             acc = cal_acc(fake_logits, c_prime)
             info = {
                 f'{flag}/loss_adv': loss_adv.item(),
                 f'{flag}/fake_loss_clf': loss_clf.item(),
                 f'{flag}/fake_acc': acc,
             }
             slot_value = (iteration + 1, hps.patch_iters) + tuple(
                 [value for value in info.values()])
             log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f'
             print(log % slot_value)
             if iteration % 100 == 0:
                 for tag, value in info.items():
                     self.logger.scalar_summary(tag, value, iteration)
         #===================== Train G =====================#
         data = next(self.data_loader)
         (c_i, c_j), (x_i_t, x_i_tk, x_i_prime,
                      x_j) = self.permute_data(data)
         # encode
         enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
             x_i_t, x_i_tk, x_i_prime, x_j)
         # decode
         x_tilde = self.decode_step(enc_i_t, c_i)
         loss_rec = torch.mean(torch.abs(x_tilde - x_i_t))
         # latent discriminate
         loss_adv = self.latent_discriminate_step(enc_i_t,
                                                  enc_i_tk,
                                                  enc_i_prime,
                                                  enc_j,
                                                  is_dis=False)
         ae_loss = loss_rec + current_alpha * loss_adv
         reset_grad([self.Encoder, self.Decoder])
         retain_graph = True if hps.n_patch_steps > 0 else False
         ae_loss.backward(retain_graph=retain_graph)
         grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
         self.ae_opt.step()
         info = {
              f'{flag}/loss_rec': loss_rec.item(),
              f'{flag}/loss_adv': loss_adv.item(),
             f'{flag}/alpha': current_alpha,
         }
         slot_value = (iteration + 1, hps.iters) + tuple(
             [value for value in info.values()])
         log = 'G:[%06d/%06d], loss_rec=%.2f, loss_adv=%.2f, alpha=%.2e'
         print(log % slot_value)
         for tag, value in info.items():
             self.logger.scalar_summary(tag, value, iteration + 1)
         # patch discriminate
         if hps.n_patch_steps > 0 and iteration >= hps.patch_start_iter:
             c_sample = self.sample_c(x_i_t.size(0))
             x_tilde = self.decode_step(enc_i_t, c_sample)
             patch_w_dis, real_logits, fake_logits = \
                     self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=False)
             patch_loss = hps.beta_dec * patch_w_dis + hps.beta_clf * c_loss
             reset_grad([self.Decoder])
             patch_loss.backward()
             grad_clip([self.Decoder], self.hps.max_grad_norm)
             self.decoder_opt.step()
             info = {
                 f'{flag}/loss_rec': loss_rec.item(),
                 f'{flag}/G_loss_clf': loss_clf.item(),
                 f'{flag}/alpha': current_alpha,
                 f'{flag}/G_acc': acc,
             }
             slot_value = (iteration + 1, hps.iters) + tuple(
                 [value for value in info.values()])
             log = 'G:[%06d/%06d], loss_rec=%.3f, loss_clf=%.2f, alpha=%.2e, acc=%.2f'
             print(log % slot_value)
             if iteration % 100 == 0:
                 for tag, value in info.items():
                     self.logger.scalar_summary(tag, value, iteration + 1)
             if iteration % 1000 == 0 or iteration + 1 == hps.iters:
                 self.save_model(model_path, iteration)
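The patchGAN generator updates above draw a random target speaker with self.sample_c(batch_size), which is also not shown. A heavily simplified sketch, assuming integer speaker indices (the real helper, and the number of speakers, may differ, e.g. it could return one-hot vectors):

import torch

def sample_c(batch_size, n_speakers=2):
    # Hypothetical: draw random target-speaker indices for the generator path.
    return torch.randint(0, n_speakers, (batch_size,), dtype=torch.long)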
Example #8
    def __init__(self, policy, ob_space, ac_space, nenvs, master_ts = 1, worker_ts = 30,
            ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=7e-4, cell = 256,
            alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear',
            algo='regular', beta=1e-3):

        print('Create Session')
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        nact = ac_space.n
        nbatch = nenvs*master_ts*worker_ts

        A = tf.placeholder(tf.int32, [nbatch])
        ADV = tf.placeholder(tf.float32, [nbatch])
        R = tf.placeholder(tf.float32, [nbatch])
        LR = tf.placeholder(tf.float32, [])

        step_model = policy(sess, ob_space, ac_space, nenvs, 1, 1, cell = cell, model='step_model', algo=algo)
        train_model = policy(sess, ob_space, ac_space, nbatch, master_ts, worker_ts, model='train_model', algo=algo)
        print('model_setting_done')

        #loss construction
        neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.wpi, labels=A)
        pg_loss = tf.reduce_mean(ADV * neglogpac)
        vf_loss = tf.reduce_mean(mse(tf.squeeze(train_model.wvf), R))
        entropy = tf.reduce_mean(cat_entropy(train_model.wpi))
        pg_loss = pg_loss - entropy * ent_coef
        print('algo: ', algo, 'max_grad_norm: ', str(max_grad_norm))
        try:
            if algo == 'regular':
                loss = pg_loss + vf_coef * vf_loss
            elif algo == 'VIB':
                '''
                implement VIB here, apart from the vf_loss and pg_loss, there should be a third loss,
                the kl_loss = ds.kl_divergence(model.encoding, prior), where prior is a Gaussian distribution with mu=0, std=1
                the final loss should be pg_loss + vf_coef * vf_loss + beta*kl_loss
                '''
                prior = ds.Normal(0.0, 1.0)
                kl_loss = tf.reduce_mean(ds.kl_divergence(train_model.encoding, prior))
                loss = pg_loss + vf_coef * vf_loss + beta*kl_loss
                # pass
            else:
                raise Exception('Algorithm not exists')
        except Exception as e:
            print(e)

        grads, global_norm = grad_clip(loss, max_grad_norm, ['model'])
        trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
        _train = trainer.apply_gradients(grads)

        lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)

        def train(wobs, whs, states, rewards, masks, actions, values):
            advs = rewards - values
            for step in range(len(whs)):
                cur_lr = lr.value()

            td_map = {train_model.wX:wobs, A:actions, ADV:advs, R:rewards, LR:cur_lr}
            if states is not None:
                td_map[train_model.wS] = states
                td_map[train_model.wM] = masks

            '''
            you can add and run additional loss for VIB here for debugging, such as kl_loss
            '''
            tloss, value_loss, policy_loss, policy_entropy, _ = sess.run(
                [loss, vf_loss, pg_loss, entropy, _train],
                feed_dict=td_map
            )
            return tloss, value_loss, policy_loss, policy_entropy

        params = find_trainable_variables("model")
        def save(save_path):
            ps = sess.run(params)
            make_path(osp.dirname(save_path))
            joblib.dump(ps, save_path)

        def load(load_path):
            loaded_params = joblib.load(load_path)
            restores = []
            for p, loaded_p in zip(params, loaded_params):
                restores.append(p.assign(loaded_p))
            ps = sess.run(restores)

        self.train_model = train_model
        self.step_model = step_model
        self.step = step_model.step
        self.value = step_model.wvalue
        self.get_wh = step_model.get_wh
        self.initial_state = step_model.w_initial_state
        self.train = train
        self.save = save
        self.load = load
        tf.global_variables_initializer().run(session=sess)
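The second docstring in Example #8 suggests fetching extra VIB terms for debugging. A minimal variant of the sess.run call inside train() that also returns the KL term when algo == 'VIB' (sketch only; kl_loss is the tensor defined above and does not exist in the 'regular' branch):

            tloss, value_loss, policy_loss, policy_entropy, kl, _ = sess.run(
                [loss, vf_loss, pg_loss, entropy, kl_loss, _train],
                feed_dict=td_map
            )
            return tloss, value_loss, policy_loss, policy_entropy, kl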
Example #9
    def train(self,
              model_path,
              flag='train',
              mode='train',
              target_guided=False):
        # load hyperparams
        hps = self.hps

        if mode == 'pretrain_AE':
            for iteration in range(hps.enc_pretrain_iters):
                data = next(self.data_loader)
                c, x = self.permute_data(data)

                # encode
                enc_act, enc = self.encode_step(x)
                x_dec = self.decode_step(enc_act, c)
                loss_rec = torch.mean(torch.abs(x_dec - x))
                reset_grad([self.Encoder, self.Decoder])
                loss_rec.backward()
                grad_clip([self.Encoder, self.Decoder], hps.max_grad_norm)
                self.ae_opt.step()

                # tb info
                info = {
                    f'{flag}/pre_loss_rec': loss_rec.item(),
                }
                slot_value = (iteration + 1, hps.enc_pretrain_iters) + tuple(
                    [value for value in info.values()])
                log = 'pre_AE:[%06d/%06d], loss_rec=%.3f'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 'ae', iteration + 1)
            print()

        elif mode == 'pretrain_C':
            for iteration in range(hps.dis_pretrain_iters):

                data = next(self.data_loader)
                c, x = self.permute_data(data)

                # encode
                enc_act, enc = self.encode_step(x)

                # classify speaker
                logits = self.clf_step(enc)
                loss_clf = self.cal_loss(logits, c)

                # update
                reset_grad([self.SpeakerClassifier])
                loss_clf.backward()
                grad_clip([self.SpeakerClassifier], hps.max_grad_norm)
                self.clf_opt.step()

                # calculate acc
                acc = self.cal_acc(logits, c)
                info = {
                    f'{flag}/pre_loss_clf': loss_clf.item(),
                    f'{flag}/pre_acc': acc,
                }
                slot_value = (iteration + 1, hps.dis_pretrain_iters) + tuple(
                    [value for value in info.values()])
                log = 'pre_C:[%06d/%06d], loss_clf=%.2f, acc=%.2f'

                print(log % slot_value, end='\r')
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 'c', iteration + 1)
            print()

        elif mode == 'train':
            for iteration in range(hps.iters):

                # calculate current alpha
                if iteration < hps.lat_sched_iters:
                    current_alpha = hps.alpha_enc * (iteration /
                                                     hps.lat_sched_iters)
                else:
                    current_alpha = hps.alpha_enc

                #==================train D==================#
                for step in range(hps.n_latent_steps):
                    data = next(self.data_loader)
                    c, x = self.permute_data(data)

                    # encode
                    enc_act, enc = self.encode_step(x)

                    # classify speaker
                    logits = self.clf_step(enc)
                    loss_clf = self.cal_loss(logits, c)
                    loss = hps.alpha_dis * loss_clf

                    # update
                    reset_grad([self.SpeakerClassifier])
                    loss.backward()
                    grad_clip([self.SpeakerClassifier], hps.max_grad_norm)
                    self.clf_opt.step()

                    # calculate acc
                    acc = self.cal_acc(logits, c)
                    info = {
                        f'{flag}/D_loss_clf': loss_clf.item(),
                        f'{flag}/D_acc': acc,
                    }
                    slot_value = (step, iteration + 1, hps.iters) + tuple(
                        [value for value in info.values()])
                    log = 'D-%d:[%06d/%06d], loss_clf=%.2f, acc=%.2f'

                    print(log % slot_value, end='\r')
                    if iteration % 100 == 0:
                        for tag, value in info.items():
                            self.logger.scalar_summary(tag, value,
                                                       iteration + 1)

                #==================train G==================#
                data = next(self.data_loader)
                c, x = self.permute_data(data)

                # encode
                enc_act, enc = self.encode_step(x)

                # decode
                x_dec = self.decode_step(enc_act, c)
                loss_rec = torch.mean(torch.abs(x_dec - x))

                # classify speaker
                logits = self.clf_step(enc)
                acc = self.cal_acc(logits, c)
                loss_clf = self.cal_loss(logits, c)

                # maximize classification loss
                loss = loss_rec - current_alpha * loss_clf
                reset_grad([self.Encoder, self.Decoder])
                loss.backward()
                grad_clip([self.Encoder, self.Decoder], hps.max_grad_norm)
                self.ae_opt.step()

                info = {
                    f'{flag}/loss_rec': loss_rec.item(),
                    f'{flag}/G_loss_clf': loss_clf.item(),
                    f'{flag}/alpha': current_alpha,
                    f'{flag}/G_acc': acc,
                }
                slot_value = (iteration + 1, hps.iters) + tuple(
                    [value for value in info.values()])
                log = 'G:[%06d/%06d], loss_rec=%.3f, loss_clf=%.2f, alpha=%.2e, acc=%.2f'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 's1', iteration + 1)
            print()

        elif mode == 'patchGAN':
            for iteration in range(hps.patch_iters):
                #==================train D==================#
                for step in range(hps.n_patch_steps):

                    data_s = next(self.source_loader)
                    data_t = next(self.target_loader)
                    _, x_s = self.permute_data(data_s)
                    c_t, x_t = self.permute_data(data_t)

                    # encode
                    enc_act, _ = self.encode_step(x_s)

                    # generator
                    x_dec = self.gen_step(enc_act, c_t)

                    # discriminator
                    w_dis, real_logits, gp = self.patch_step(x_t,
                                                             x_dec,
                                                             is_dis=True)

                    # aux classification loss
                    loss_clf = self.cal_loss(real_logits, c_t, shift=True)

                    loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                    reset_grad([self.PatchDiscriminator])
                    loss.backward()
                    grad_clip([self.PatchDiscriminator], hps.max_grad_norm)
                    self.patch_opt.step()

                    # calculate acc
                    acc = self.cal_acc(real_logits, c_t, shift=True)
                    info = {
                        f'{flag}/w_dis': w_dis.item(),
                        f'{flag}/gp': gp.item(),
                        f'{flag}/real_loss_clf': loss_clf.item(),
                        f'{flag}/real_acc': acc,
                    }
                    slot_value = (step, iteration + 1,
                                  hps.patch_iters) + tuple(
                                      [value for value in info.values()])
                    log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f'
                    print(log % slot_value, end='\r')

                    if iteration % 100 == 0:
                        for tag, value in info.items():
                            self.logger.scalar_summary(tag, value,
                                                       iteration + 1)

                #==================train G==================#
                data_s = next(self.source_loader)
                data_t = next(self.target_loader)
                _, x_s = self.permute_data(data_s)
                c_t, x_t = self.permute_data(data_t)

                # encode
                enc_act, _ = self.encode_step(x_s)

                # generator
                x_dec = self.gen_step(enc_act, c_t)

                # discriminator
                loss_adv, fake_logits = self.patch_step(x_t,
                                                        x_dec,
                                                        is_dis=False)

                # aux classification loss
                loss_clf = self.cal_loss(fake_logits, c_t, shift=True)
                loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
                reset_grad([self.Generator])
                loss.backward()
                grad_clip([self.Generator], hps.max_grad_norm)
                self.gen_opt.step()

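                # NOTE: target_guided is assumed to be a flag/hyperparameter defined
                # outside this snippet; when set, the generator also gets a
                # teacher-forced reconstruction update on target-domain data below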
                if target_guided:
                    # teacher forcing
                    enc_tf, _ = self.encode_step(x_t)
                    x_dec_tf = self.gen_step(enc_tf, c_t)
                    loss_rec = torch.mean(torch.abs(x_dec_tf - x_t))
                    reset_grad([self.Generator])
                    loss_rec.backward()
                    self.gen_opt.step()

                # calculate acc
                acc = self.cal_acc(fake_logits, c_t, shift=True)
                info = {
                    f'{flag}/loss_adv': loss_adv.item(),
                    f'{flag}/fake_loss_clf': loss_clf.item(),
                    f'{flag}/fake_acc': acc,
                    f'{flag}/tg_rec':
                    loss_rec.item() if target_guided else 0.000,
                }
                slot_value = (iteration + 1, hps.patch_iters) + tuple(
                    [value for value in info.values()])
                log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f, tg_rec=%.3f'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 's2', iteration + 1)
            print()

        elif mode == 'autolocker':
            criterion = torch.nn.BCELoss()
            for iteration in range(hps.patch_iters):
                #==================train G==================#
                data_s = next(self.source_loader)
                data_t = next(self.target_loader)
                _, x_s = self.permute_data(data_s)
                c_t, x_t = self.permute_data(data_t)

                # encode
                enc_act, _ = self.encode_step(x_s)

                # decode
                residual_output = self.gen_step(enc_act, c_t)

                # re-encode
                re_enc, _ = self.encode_step(residual_output)

                # re-encode loss
                loss_reenc = criterion(re_enc, enc_act.data)
                reset_grad([self.Encoder, self.Decoder, self.Generator])
                loss_reenc.backward()
                grad_clip([self.Generator], hps.max_grad_norm)
                self.gen_opt.step()

                if target_guided:
                    # teacher forcing
                    enc_tf, _ = self.encode_step(x_t)
                    x_dec_tf = self.gen_step(enc_tf, c_t)
                    loss_rec = torch.mean(torch.abs(x_dec_tf - x_t))
                    reset_grad([self.Encoder, self.Decoder, self.Generator])
                    loss_rec.backward()
                    self.gen_opt.step()

                # calculate acc
                info = {
                    f'{flag}/re_enc': loss_reenc.item(),
                    f'{flag}/tg_rec':
                    loss_rec.item() if target_guided else 0.000,
                }
                slot_value = (iteration + 1, hps.patch_iters) + tuple(
                    [value for value in info.values()])
                log = 'patch_G:[%06d/%06d], re_enc=%.3f, tg_rec=%.3f'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 's2', iteration + 1)
            print()

        elif mode == 't_classify':
            for iteration in range(hps.tclf_iters):
                #======train target classifier======#
                data = next(self.data_loader)
                c, x = self.permute_data(data)
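                # remap every label below 100 to id 102 before re-basing with
                # self.shift_c below (assumed to collapse non-target speakers
                # into a single class for the target classifier)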
                c[c < 100] = 102

                # classification
                logits = self.tclf_step(x)

                # classification loss
                loss = self.cal_loss(logits, c - self.shift_c)
                reset_grad([self.TargetClassifier])
                loss.backward()
                grad_clip([self.TargetClassifier], hps.max_grad_norm)
                self.tclf_opt.step()

                # calculate acc
                acc = self.cal_acc(logits, c - self.shift_c)
                info = {
                    f'{flag}/acc': acc,
                }
                slot_value = (iteration + 1, hps.tclf_iters) + tuple(
                    [value for value in info.values()])
                log = 'Target Classifier:[%05d/%05d], acc=%.2f'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 'tclf', iteration + 1)
            print()

        elif mode == 'train_Tacotron':

            assert self.g_mode == 'tacotron'
            criterion = TacotronLoss()
            self.Encoder.eval()

            for iteration in range(hps.tacotron_iters):
                #======train tacotron======#

                cur_lr = learning_rate_decay(init_lr=0.002,
                                             global_step=iteration)
                for param_group in self.gen_opt.param_groups:
                    param_group['lr'] = cur_lr

                data = next(self.data_loader)
                c, x, m = self.permute_data(data, load_mel=True)

                # encode
                enc_act, enc = self.encode_step(x)

                # tacotron synthesis
                m_dec, x_dec = self.tacotron_step(enc_act.data, m, c)

                # reconstruction loss
                loss_rec = criterion([m_dec, x_dec], [m, x])
                reset_grad([self.Generator])
                loss_rec.backward()
                grad_clip([self.Generator], hps.max_grad_norm)
                self.gen_opt.step()

                # tb info
                info = {
                    f'{flag}/tacotron_loss_rec': loss_rec.item(),
                    f'{flag}/tacotron_lr': cur_lr,
                }
                slot_value = (iteration + 1, hps.tacotron_iters) + tuple(
                    [value for value in info.values()])
                log = 'train_Tacotron:[%06d/%06d], loss_rec=%.3f, lr=%.2e'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 't', iteration + 1)
            print()

        else:
            raise NotImplementedError()
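
The train_Tacotron branch above calls a learning_rate_decay helper that is not shown in the snippet. A minimal sketch, assuming the common Noam-style schedule (linear warmup followed by inverse-square-root decay) found in most open-source Tacotron trainers; the warmup_steps default is an assumption.

def learning_rate_decay(init_lr, global_step, warmup_steps=4000.0):
    # Noam-style schedule: linear warmup for warmup_steps, then 1/sqrt(step) decay
    step = float(global_step + 1)
    return init_lr * warmup_steps ** 0.5 * min(step * warmup_steps ** -1.5, step ** -0.5)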
예제 #10
0
    def train(self, model_path, flag='train', mode='train'):
        # load hyperparams
        hps = self.hps

        if mode == 'pretrain_AE':
            for iteration in range(hps.enc_pretrain_iters):
                data = next(self.data_loader)
                c, x = self.permute_data(data)

                # encode
                enc = self.encode_step(x)
                x_tilde = self.decode_step(enc, c)
                loss_rec = torch.mean(torch.abs(x_tilde - x))
                reset_grad([self.Encoder, self.Decoder])
                loss_rec.backward()
                grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
                self.ae_opt.step()

                # tb info
                info = {
                    f'{flag}/pre_loss_rec': loss_rec.item(),
                }
                slot_value = (iteration + 1, hps.enc_pretrain_iters) + \
                    tuple([value for value in info.values()])
                log = 'pre_AE:[%06d/%06d], loss_rec=%.3f'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 'ae', iteration + 1)
            print()

        elif mode == 'pretrain_C':
            for iteration in range(hps.dis_pretrain_iters):

                data = next(self.data_loader)
                c, x = self.permute_data(data)

                # encode
                enc = self.encode_step(x)

                # classify speaker
                logits = self.clf_step(enc)
                loss_clf = self.cal_loss(logits, c)

                # update
                reset_grad([self.SpeakerClassifier])
                loss_clf.backward()
                grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
                self.clf_opt.step()

                # calculate acc
                acc = self.cal_acc(logits, c)
                info = {
                    f'{flag}/pre_loss_clf': loss_clf.item(),
                    f'{flag}/pre_acc': acc,
                }
                slot_value = (iteration + 1, hps.dis_pretrain_iters) + \
                    tuple([value for value in info.values()])
                log = 'pre_C:[%06d/%06d], loss_clf=%.2f, acc=%.2f'

                print(log % slot_value, end='\r')
                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 'c', iteration + 1)
            print()

        elif mode == 'train':
            for iteration in range(hps.iters):

                # calculate current alpha
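                # alpha ramps linearly from 0 to hps.alpha_enc over the first
                # lat_sched_iters iterations, then stays constant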
                if iteration < hps.lat_sched_iters:
                    current_alpha = hps.alpha_enc * \
                        (iteration / hps.lat_sched_iters)
                else:
                    current_alpha = hps.alpha_enc

                #==================train D==================#
                for step in range(hps.n_latent_steps):
                    data = next(self.data_loader)
                    c, x = self.permute_data(data)

                    # encode
                    enc = self.encode_step(x)
                    _, z_mean, z_log_var = enc
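                    # closed-form KL divergence between N(z_mean, exp(z_log_var))
                    # and the standard normal prior, summed over latent dimensions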
                    kl_loss = (1 + z_log_var - z_mean**2 -
                               torch.exp(z_log_var)).sum(-1) * -.5
                    kl_loss = kl_loss.sum()
                    # classify speaker
                    logits = self.clf_step(enc)
                    loss_clf = self.cal_loss(logits, c)
                    loss = hps.alpha_dis * loss_clf + kl_loss

                    # update
                    reset_grad([self.SpeakerClassifier])
                    loss.backward()
                    grad_clip([self.SpeakerClassifier], self.hps.max_grad_norm)
                    self.clf_opt.step()

                    # calculate acc
                    acc = self.cal_acc(logits, c)
                    info = {
                        f'{flag}/D_loss_clf': loss_clf.item(),
                        f'{flag}/D_acc': acc,
                    }
                    slot_value = (step, iteration + 1, hps.iters) + \
                        tuple([value for value in info.values()])
                    log = 'D-%d:[%06d/%06d], loss_clf=%.2f, acc=%.2f'

                    print(log % slot_value, end='\r')
                    if iteration % 100 == 0:
                        for tag, value in info.items():
                            self.logger.scalar_summary(tag, value,
                                                       iteration + 1)
                #==================train G==================#
                data = next(self.data_loader)
                c, x = self.permute_data(data)

                # encode
                enc = self.encode_step(x)

                # decode
                x_tilde = self.decode_step(enc, c)
                loss_rec = torch.mean(torch.abs(x_tilde - x))

                # classify speaker
                logits = self.clf_step(enc)
                acc = self.cal_acc(logits, c)
                loss_clf = self.cal_loss(logits, c)

                # maximize classification loss
                loss = loss_rec - current_alpha * loss_clf
                reset_grad([self.Encoder, self.Decoder])
                loss.backward()
                grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
                self.ae_opt.step()

                info = {
                    f'{flag}/loss_rec': loss_rec.item(),
                    f'{flag}/G_loss_clf': loss_clf.item(),
                    f'{flag}/alpha': current_alpha,
                    f'{flag}/G_acc': acc,
                }
                slot_value = (iteration + 1, hps.iters) + \
                    tuple([value for value in info.values()])
                log = 'G:[%06d/%06d], loss_rec=%.3f, loss_clf=%.2f, alpha=%.2e, acc=%.2f'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 's1', iteration + 1)
            print()

        elif mode == 'patchGAN':
            for iteration in range(hps.patch_iters):
                #==================train D==================#
                for step in range(hps.n_patch_steps):

                    data_s = next(self.source_loader)
                    data_t = next(self.target_loader)
                    _, x_s = self.permute_data(data_s)
                    c, x_t = self.permute_data(data_t)

                    # encode
                    enc = self.encode_step(x_s)

                    # sample c
                    c_prime = self.sample_c(x_t.size(0))

                    # generator
                    x_tilde = self.gen_step(enc, c_prime)

                    # discriminator
                    w_dis, real_logits, gp = self.patch_step(x_t,
                                                             x_tilde,
                                                             is_dis=True)

                    # aux classification loss
                    loss_clf = self.cal_loss(real_logits, c, shift=True)

                    loss = -hps.beta_dis * w_dis + hps.beta_clf * loss_clf + hps.lambda_ * gp
                    reset_grad([self.PatchDiscriminator])
                    loss.backward()
                    grad_clip([self.PatchDiscriminator],
                              self.hps.max_grad_norm)
                    self.patch_opt.step()

                    # calculate acc
                    acc = self.cal_acc(real_logits, c, shift=True)
                    info = {
                        f'{flag}/w_dis': w_dis.item(),
                        f'{flag}/gp': gp.item(),
                        f'{flag}/real_loss_clf': loss_clf.item(),
                        f'{flag}/real_acc': acc,
                    }
                    slot_value = (step, iteration + 1, hps.patch_iters) + \
                        tuple([value for value in info.values()])
                    log = 'patch_D-%d:[%06d/%06d], w_dis=%.2f, gp=%.2f, loss_clf=%.2f, acc=%.2f'
                    print(log % slot_value, end='\r')

                    if iteration % 100 == 0:
                        for tag, value in info.items():
                            self.logger.scalar_summary(tag, value,
                                                       iteration + 1)

                #==================train G==================#
                data_s = next(self.source_loader)
                data_t = next(self.target_loader)
                _, x_s = self.permute_data(data_s)
                c, x_t = self.permute_data(data_t)

                # encode
                enc = self.encode_step(x_s)

                # sample c
                c_prime = self.sample_c(x_t.size(0))

                # generator
                x_tilde = self.gen_step(enc, c_prime)

                # discriminator
                loss_adv, fake_logits = self.patch_step(x_t,
                                                        x_tilde,
                                                        is_dis=False)

                # aux classification loss
                loss_clf = self.cal_loss(fake_logits, c_prime, shift=True)
                loss = hps.beta_clf * loss_clf + hps.beta_gen * loss_adv
                reset_grad([self.Generator])
                loss.backward()
                grad_clip([self.Generator], self.hps.max_grad_norm)
                self.gen_opt.step()

                # calculate acc
                acc = self.cal_acc(fake_logits, c_prime, shift=True)
                info = {
                    f'{flag}/loss_adv': loss_adv.item(),
                    f'{flag}/fake_loss_clf': loss_clf.item(),
                    f'{flag}/fake_acc': acc,
                }
                slot_value = (iteration + 1, hps.patch_iters) + \
                    tuple([value for value in info.values()])
                log = 'patch_G:[%06d/%06d], loss_adv=%.2f, loss_clf=%.2f, acc=%.2f'
                print(log % slot_value, end='\r')

                if iteration % 100 == 0:
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration + 1)
                if (iteration + 1) % 1000 == 0:
                    self.save_model(model_path, 's2',
                                    iteration + 1 + hps.iters)
            print()

        else:
            raise NotImplementedError()
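
The reset_grad, grad_clip and cal_acc helpers used throughout these trainers are defined elsewhere in the project. A minimal PyTorch sketch of what they are assumed to do:

import torch

def reset_grad(net_list):
    # zero the accumulated gradients of every module before the next backward pass
    for net in net_list:
        net.zero_grad()

def grad_clip(net_list, max_grad_norm):
    # clip the gradient norm of all parameters in the given modules
    for net in net_list:
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm)

def cal_acc(logits, y_true):
    # fraction of argmax predictions that match the ground-truth labels
    _, y_pred = torch.max(logits, dim=1)
    return (y_pred == y_true).float().mean().item()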
예제 #11
0
    def __init__(self, policy, ob_space, ac_space, nenvs, master_ts = 1, worker_ts = 8,
                 ent_coef=0.01, vf_coef=0.5, max_grad_norm=2.5, lr=7e-4, cell=256,
                 ib_alpha=0.04, sv_M=32, algo='use_svib_uniform',
                 alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'):

        sess = tf_util.make_session()
        nact = ac_space.n
        nbatch = nenvs * master_ts * worker_ts  # total batch size: environments * master timesteps * worker timesteps

        # A:action, ADV:advantage, R:reward, LR:Learning Rate
        A = tf.placeholder(tf.int32, [nbatch])
        ADV = tf.placeholder(tf.float32, [nbatch])
        R = tf.placeholder(tf.float32, [nbatch])
        LR = tf.placeholder(tf.float32, [])

        step_model = policy(sess, ob_space, ac_space, nenvs, 1, 1, cell=cell, M=sv_M, model='step_model', algo=algo)
        train_model = policy(sess, ob_space, ac_space, nbatch, master_ts, worker_ts, cell = cell, M=sv_M, model='train_model', algo=algo)
        print('model_setting_done, algorithm:', str(algo))

        '''
        Visualize the mutual information; temporarily skipped.
        '''
        ib_loss = train_model.mi_xh_loss
        T = train_model.T_value
        t_grads, t_global_norm = grad_clip(-vf_coef*ib_loss, max_grad_norm, ['model/T/update_params'])
        t_trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
        _t_train = t_trainer.apply_gradients(t_grads)
        T_update_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='model/T/update_params')
        T_orig_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='model/T/orig_params')
        reset_update_params = [update_param.assign(orig_param) for update_param, orig_param in zip(T_update_params, T_orig_params)]

        # rpf_matrix, rpf_grads = rpf_kernel(vf_loss_sv, rpf_h)

        if algo == 'use_svib_uniform' or algo == 'use_svib_gaussian':
            def expand_placeholder(X, M=sv_M):
                return tf.tile(tf.expand_dims(X, axis=-1), [1, M])
            A_expand, R_expand = expand_placeholder(A), expand_placeholder(R)
            neglogpac_expand = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.wpi_expand, labels=A_expand)#shape=[nbatch, sv_M]
            # pg_loss_expand = tf.reduce_mean(ADV_expand * neglogpac_expand, axis=-1)
            pg_loss_expand = tf.reduce_mean(tf.stop_gradient(R_expand-train_model.wvf_expand[:,:,0]) * neglogpac_expand, axis=-1)
            vf_loss_expand = tf.reduce_mean(mse(tf.squeeze(train_model.wvf_expand), R_expand), axis=-1)
            entropy_expand = tf.reduce_mean(cat_entropy(train_model.wpi_expand), axis=-1)#shape=[nbatch]
            J_theta = -(pg_loss_expand + vf_coef*vf_loss_expand - ent_coef*entropy_expand)
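            # J_theta is the A2C objective per sample (policy gradient + value loss
            # - entropy bonus), averaged over the sv_M particles and negated;
            # loss_expand = -J_theta / nbatch turns it back into a minimization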

            loss_expand = -J_theta / float(nbatch)
            pg_loss_expand_ = tf.reduce_mean(pg_loss_expand)
            vf_loss_expand_ = tf.reduce_mean(vf_loss_expand)
            entropy_expand_ = tf.reduce_mean(entropy_expand)
            loss_expand_ = -tf.reduce_mean(J_theta)

            print('ib_alpha: ', ib_alpha)
            log_p_grads = tf.gradients(J_theta/np.sqrt(ib_alpha), [train_model.wh_expand])[0]#shape=[nbatch, sv_M, cell]
            if algo == 'use_svib_gaussian':
                mean, var = tf.nn.moments(train_model.wh_expand, axes=1, keep_dims=True)#shape=[nbatch, 1,cell]
                gaussian_grad = -(train_model.wh_expand - mean)/(float(sv_M) * (var+1e-3))
                log_p_grads += 5e-3*(tf_l2norm(log_p_grads, axis=-1, keep_dims=True)/tf_l2norm(gaussian_grad, axis=-1, keep_dims=True))*gaussian_grad
            sv_grads = tf.constant(0., tf.float32, shape=[nbatch, 0, cell])
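            # Stein variational gradient per particle: kernel-weighted sum of the
            # objective gradients (attraction) plus sqrt(ib_alpha) times the kernel
            # gradients (repulsion)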
            for i in range(sv_M):
                sv_grad = tf.reduce_sum(train_model.rpf_matrix[:, :, i:i+1] * log_p_grads, axis=1) + np.sqrt(ib_alpha)*train_model.rpf_grads[:, i, :]#shape=[nbatch, cell]
                sv_grads = tf.concat([sv_grads, tf.expand_dims(sv_grad, axis=1)], axis=1)
                
            SV_GRADS = tf.placeholder(tf.float32, [nbatch, sv_M, cell])
            repr_loss = tf.reduce_mean(SV_GRADS * train_model.wh_expand, axis=1)#shape=[nbatch,cell]
            repr_loss = -tf.reduce_mean(tf.reduce_sum(repr_loss, axis=-1))#max optimization problem to minimization problem
            # repr_loss = -tf.reduce_mean(repr_loss, axis=0)

            # sv_grad_ = tf.reduce_sum(train_model.rpf_matrix[:, :, 2:3] * log_p_grads, axis=1) + train_model.rpf_grads[:, 2, :]
            # exploit_term = tf.reduce_sum(train_model.rpf_matrix[:, :, 2:3] * log_p_grads, axis=1)
            # explore_term = train_model.rpf_grads[:, 2, :]
            grads_expand, global_norm_expand = grad_clip(loss_expand, max_grad_norm, ['model/worker_module'])
            trainer_expand = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
            _train_expand = trainer_expand.apply_gradients(grads_expand)

            repr_grads, repr_global_norm = grad_clip(repr_loss, max_grad_norm, ['model/ordinary_encoder'])
            repr_trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
            _repr_train = repr_trainer.apply_gradients(repr_grads)

        elif algo == 'sv_a2c':
            def expand_placeholder(X, M=sv_M):
                return tf.tile(tf.expand_dims(X, axis=-1), [1, M])
            A_expand, R_expand = expand_placeholder(A), expand_placeholder(R) # [40, 32]
            sigma = tf.constant(1e-5)
            neglogpac_expand = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.wpi_expand, labels=A_expand) + sigma # [40, 32]
            pg_loss_expand = tf.reduce_mean(tf.stop_gradient(R_expand - train_model.wvf_expand[:, :, 0]) * neglogpac_expand, axis=-1) # [40, ]
            vf_loss_sv = tf.expand_dims(mse(tf.squeeze(train_model.wvf_expand), R_expand), axis=-1) # [40, 32, 1]
            vf_loss_expand = tf.reduce_mean(mse(tf.squeeze(train_model.wvf_expand), R_expand), axis=-1) # [40, ]
            entropy_expand = tf.reduce_mean(cat_entropy(train_model.wpi_expand), axis=-1)  # shape=[nbatch]

            J_theta = pg_loss_expand + vf_coef * vf_loss_expand - ent_coef * entropy_expand # [40, ]
            # why divide by nbatch?
            loss_expand = J_theta / float(nbatch) # [40, ]

            pg_loss_expand_ = tf.reduce_mean(pg_loss_expand)
            vf_loss_expand_ = tf.reduce_mean(vf_loss_expand) # [1]
            entropy_expand_ = tf.reduce_mean(entropy_expand)
            loss_expand_ = tf.reduce_mean(J_theta)

            print('ib_alpha: ', ib_alpha)
            # mean, var = tf.constant(0., tf.float32, [nbatch, 1, 1]), tf.constant(1, tf.float32, [nbatch, 1, 1])
            mean, var = tf.nn.moments(vf_loss_sv, axes=1, keep_dims=True) # [40, 1, 1]
            # Problem 1: issue with the Gaussian gradient calculation
            log_p_grads = -(vf_loss_sv - mean) / (float(sv_M) * (var))

            sv_grads = tf.constant(0., tf.float32, shape=[nbatch, 0, 1]) # [nbatch, m, 1]

            rpf_h = self.h_coef(vf_loss_sv, sv_M)
            rpf_matrix, rpf_grads = self.rpf_kernel(vf_loss_sv, rpf_h, sv_M)
            for i in range(sv_M):
                # sv_grad = tf.reduce_sum(train_model.rpf_matrix[:, :, i:i+1] * log_p_grads, axis=1) + sqrt(ib_alpha) * train_model.rpf_grads[:, i, :] #shape=[nbatch, cell]
                sv_grad = tf.reduce_sum(rpf_matrix[:, :, i:i + 1] * log_p_grads, axis=1) + rpf_grads[:, i, :]
                sv_grads = tf.concat([sv_grads, tf.expand_dims(sv_grad, axis=1)], axis=1)

            SV_GRADS = tf.placeholder(tf.float32, [nbatch, sv_M, 1])
            sv_loss = tf.reduce_mean(SV_GRADS * vf_loss_sv, axis=1)

            loss_expand -=  ib_alpha * (tf_l2norm(loss_expand, axis=-1, keep_dims=True)/tf_l2norm(sv_loss, axis=-1, keep_dims=True)) * sv_loss

            grads_expand, global_norm_expand = grad_clip(loss_expand, max_grad_norm, ['model/worker_module', 'model/ordinary_encoder'])
            trainer_expand = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
            _train_expand = trainer_expand.apply_gradients(grads_expand)

            # sv_loss_grads, sv_global_norm = grad_clip(sv_loss, max_grad_norm, ['model/worker_module/comm', 'model/worker_module/w_value'])
            # sv_trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
            # _sv_train = sv_trainer.apply_gradients(sv_loss_grads)


        elif algo == 'anchor':
            neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.wpi, labels=A)
            pg_loss = tf.reduce_mean(ADV * neglogpac)
            entropy = tf.reduce_mean(cat_entropy(train_model.wpi))
            vf_loss = tf.reduce_mean(mse(tf.squeeze(train_model.wvf), R))

            # anchor method
            param_list = []
            for scope in ['model/worker_module']:
                List = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
                print(len(List))
                param_list += List

            param_value_layer1_w = param_list[0]
            param_value_layer1_b = param_list[1]
            param_policy_layer2_w = param_list[2]
            param_value_layer2_w = param_list[4]
            param_value_layer2_b = param_list[5]

            init_stddev = 5.0 # 7.0
            init_stddev_2 = 0.18 / np.sqrt(cell)  # normal scaling
            lambda_anchor = [0.000001,0.1]

            layer1_w_init = tf.random_normal(mean=0., stddev=init_stddev, shape=param_value_layer1_w.get_shape())
            layer1_b_init = tf.random_normal(mean=0., stddev=init_stddev, shape=param_value_layer1_b.get_shape())
            layer2_w_init = tf.random_normal(mean=0, stddev=init_stddev_2, shape=param_value_layer2_w.get_shape())
            layer2_b_init = tf.random_normal(mean=0, stddev=init_stddev_2, shape=param_value_layer2_b.get_shape())

            loss_anchor = lambda_anchor[0] / nbatch * tf.reduce_sum(tf.square(layer1_w_init - param_value_layer1_w))
            loss_anchor += lambda_anchor[0] / nbatch * tf.reduce_sum(tf.square(layer1_b_init - param_value_layer1_b))
            loss_anchor += lambda_anchor[1] / nbatch * tf.reduce_sum(tf.square(layer2_w_init - param_value_layer2_w))
            loss_anchor += lambda_anchor[1] / nbatch * tf.reduce_sum(tf.square(layer2_b_init - param_value_layer2_b))

            loss = pg_loss + vf_coef * vf_loss - ent_coef * entropy + loss_anchor

            grads, global_norm = grad_clip(loss, max_grad_norm, ['model/worker_module', 'model/ordinary_encoder'])
            trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
            _train = trainer.apply_gradients(grads)

        else: # regular algorithm
            neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.wpi, labels=A)
            pg_loss = tf.reduce_mean(ADV * neglogpac)
            entropy = tf.reduce_mean(cat_entropy(train_model.wpi))
            vf_loss = tf.reduce_mean(mse(tf.squeeze(train_model.wvf), R))

            loss = pg_loss + vf_coef * vf_loss - ent_coef * entropy

            grads, global_norm = grad_clip(loss, max_grad_norm, ['model/worker_module', 'model/ordinary_encoder'])
            trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
            _train = trainer.apply_gradients(grads)

        params = find_trainable_variables("model")
        lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)


        def train(wobs, whs, states, rewards, masks, actions, values, noises):
            advs = rewards - values
            # adv_mu, adv_var = np.mean(advs), np.var(advs)+1e-3
            # advs = (advs - adv_mu) / adv_var

            for step in range(len(whs)):
                cur_lr = lr.value()
            sv_td_map = {train_model.wX : wobs, train_model.istraining:True, A:actions, R:rewards, LR:cur_lr}

            # Sess Graph
            # writer = tf.summary.FileWriter('./', sess.graph)

            repr_td_map = {train_model.wX: wobs, train_model.istraining: True, A: actions, R: rewards, LR: cur_lr}
            rl_td_map = {train_model.wX : wobs, train_model.istraining: True, A:actions, ADV:advs, R:rewards, LR:cur_lr}
            if states is not None:
                rl_td_map[train_model.wS] = states
                rl_td_map[train_model.wM] = masks
            repr_grad_norm = 0.
            # print(str(np.sum(whs-sess.run(train_model.wh, feed_dict={train_model.wX : wobs, train_model.istraining:True, train_model.noise:noises}))))
            if algo == 'use_svib_uniform' or algo == 'use_svib_gaussian':
                repr_td_map[train_model.noise_expand], repr_td_map[train_model.NOISE_KEEP] = sess.run(train_model.noise_expand), noises
                wh_expands, sv_gradients = sess.run([train_model.wh_expand, sv_grads], feed_dict=repr_td_map)
                rl_td_map[train_model.wh_expand] = wh_expands
                tloss, value_loss, policy_loss, policy_entropy, rl_grad_norm, _ = sess.run(
                    [loss_expand_, vf_loss_expand_, pg_loss_expand_, entropy_expand_, global_norm_expand, _train_expand],
                    feed_dict=rl_td_map
                )
                repr_td_map[SV_GRADS] = sv_gradients
                # if algo == 'use_svib_gaussian':
                #     gaussian_gradients, repr_grad_norm, __ =\
                #         sess.run([gaussian_grad, repr_global_norm, _repr_train], feed_dict=repr_td_map)
                #     return tloss, value_loss, policy_loss, policy_entropy, rl_grad_norm, gaussian_gradients, repr_grad_norm  # represnet_loss, SV_GRAD, EXPLOIT, LOG_P_GRADS, EXPLORE
                repr_grad_norm, represent_loss, __ = sess.run([repr_global_norm, repr_loss, _repr_train], feed_dict=repr_td_map)
                # keep the return signature consistent with the other branches
                anchor_loss = 0.
                sv_loss_ = 0.

            elif algo == 'anchor':
                rl_td_map[train_model.wX] = wobs
                tloss, value_loss, policy_loss, policy_entropy, rl_grad_norm, anchor_loss, _ = sess.run(
                    [loss, vf_loss, pg_loss, entropy, global_norm, loss_anchor ,_train],
                    feed_dict=rl_td_map
                )

                represent_loss = 0.
                sv_loss_ = 0
            elif algo == 'sv_a2c':
                sv_td_map[train_model.noise_expand], sv_td_map[train_model.NOISE_KEEP] = sess.run(
                    train_model.noise_expand), noises
                wvf_expands, sv_gradients = sess.run([train_model.wvf_expand, sv_grads], feed_dict=sv_td_map)
                rl_td_map[train_model.wvf_expand] = wvf_expands
                rl_td_map[train_model.noise_expand], rl_td_map[train_model.NOISE_KEEP] = sess.run(
                    train_model.noise_expand), noises
                rl_td_map[SV_GRADS] = sv_gradients
                tloss, value_loss, policy_loss, policy_entropy, rl_grad_norm, sv_loss_, _ = sess.run(
                    [loss_expand_, vf_loss_expand_, pg_loss_expand_, entropy_expand_, global_norm_expand,
                     sv_loss, _train_expand],
                    feed_dict=rl_td_map
                )
                sv_td_map[SV_GRADS] = sv_gradients
                anchor_loss = 0.
                represent_loss = 0.

            else:
                rl_td_map[train_model.wX], rl_td_map[train_model.noise] = wobs, noises#noise won't be used when algo is 'regular'
                tloss, value_loss, policy_loss, policy_entropy, rl_grad_norm, _ = sess.run(
                    [loss, vf_loss, pg_loss, entropy, global_norm, _train],
                    feed_dict=rl_td_map
                )
                # repr_td_map[WH_GRADS] = wh_gradients
                # repr_grad_norm, __ = sess.run([ordin_repr_global_norm, _ordin_repr_train], feed_dict=repr_td_map)
                repr_grad_norm = 0.
                represent_loss = 0.
                anchor_loss = 0
                sv_loss_ = 0
            return tloss, value_loss, policy_loss, policy_entropy, rl_grad_norm, repr_grad_norm, represent_loss, anchor_loss, sv_loss_#SV_GRAD, EXPLOIT, LOG_P_GRADS, EXPLORE

        def train_mine(wobs, whs, steps=256, lr=7e-4):
            # whs_std = (whs-np.mean(whs,axis=0,keepdims=True))/(1e-8 + np.std(whs,axis=0,keepdims=True))
            idx = np.arange(len(whs))
            ___ = sess.run(reset_update_params)
            for i in range(int(steps)):
                np.random.shuffle(idx)
                mi, T_value, __ = sess.run([ib_loss, T, _t_train],
                                           feed_dict={train_model.wX: wobs[idx], train_model.wh: whs[idx],
                                                      LR: lr, train_model.istraining: True})
            logger.record_tabular('mutual_info_loss', float(mi))
            logger.record_tabular('T_value', float(T_value))
            logger.dump_tabular()

        def save(save_path):
            ps = sess.run(params)
            make_path(osp.dirname(save_path))
            joblib.dump(ps, save_path)

        def load(load_path):
            loaded_params = joblib.load(load_path)
            restores = []
            for p, loaded_p in zip(params, loaded_params):
                restores.append(p.assign(loaded_p))
            ps = sess.run(restores)

        self.train = train
        self.train_mine = train_mine
        self.train_model = train_model
        self.step_model = step_model
        self.get_wh = step_model.get_wh
        self.get_noise = step_model.get_noise
        self.value = step_model.wvalue
        self.step = step_model.step
        self.initial_state = step_model.w_initial_state
        self.save = save
        self.load = load
        self.sv_M = sv_M
        # self.rpf_h = rpf_h
        # self.rpf_matrix = rpf_matrix
        # self.rpf_grads = rpf_grads
        tf.global_variables_initializer().run(session=sess)
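
The A2C model above relies on a TensorFlow-1.x grad_clip(loss, max_grad_norm, scopes) helper that returns (grad, var) pairs plus the global norm, and on mse / cat_entropy utilities in the style of the A2C baselines. A minimal sketch under those assumptions:

import tensorflow as tf

def grad_clip(loss, max_grad_norm, scopes):
    # gather trainable variables under the given scopes, differentiate the loss,
    # and clip by global norm; returns (grad, var) pairs ready for apply_gradients
    params = []
    for scope in scopes:
        params += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
    grads = tf.gradients(loss, params)
    grads, global_norm = tf.clip_by_global_norm(grads, max_grad_norm)
    return list(zip(grads, params)), global_norm

def mse(pred, target):
    # per-element squared error (halved), as in the baselines a2c utilities
    return tf.square(pred - target) / 2.

def cat_entropy(logits):
    # entropy of a categorical distribution parameterized by logits
    a0 = logits - tf.reduce_max(logits, axis=-1, keep_dims=True)
    ea0 = tf.exp(a0)
    z0 = tf.reduce_sum(ea0, axis=-1, keep_dims=True)
    p0 = ea0 / z0
    return tf.reduce_sum(p0 * (tf.log(z0) - a0), axis=-1)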
예제 #12
0
 def train(self, model_path, flag='train'):
     # load hyperparams
     hps = self.hps
     for iteration in range(hps.iters):
         # calculate current alpha
         if iteration + 1 < hps.lat_sched_iters and iteration >= hps.enc_pretrain_iters:
             current_alpha = hps.alpha_enc * (
                 iteration + 1 - hps.enc_pretrain_iters) / (
                     hps.lat_sched_iters - hps.enc_pretrain_iters)
         else:
             current_alpha = 0
         if iteration >= hps.enc_pretrain_iters:
             n_latent_steps = hps.n_latent_steps \
                 if iteration > hps.enc_pretrain_iters else hps.dis_pretrain_iters
             for step in range(n_latent_steps):
                 #===================== Train latent discriminator =====================#
                 data = next(self.data_loader)
                 (c_i, c_j), (x_i_t, x_i_tk, x_i_prime,
                              x_j) = self.permute_data(data)
                 # encode
                 enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
                     x_i_t, x_i_tk, x_i_prime, x_j)
                 # latent discriminate
                 latent_w_dis, latent_gp = self.latent_discriminate_step(
                     enc_i_t, enc_i_tk, enc_i_prime, enc_j)
                 lat_loss = -hps.alpha_dis * latent_w_dis + hps.lambda_ * latent_gp
                 reset_grad([self.LatentDiscriminator])
                 lat_loss.backward()
                 grad_clip([self.LatentDiscriminator],
                           self.hps.max_grad_norm)
                 self.lat_opt.step()
                 # print info
                 info = {
                     f'{flag}/D_latent_w_dis': latent_w_dis.data[0],
                     f'{flag}/latent_gp': latent_gp.data[0],
                 }
                 slot_value = (step, iteration + 1, hps.iters) + \
                         tuple([value for value in info.values()])
                 log = 'lat_D-%d:[%06d/%06d], w_dis=%.3f, gp=%.2f'
                 print(log % slot_value)
                 for tag, value in info.items():
                     self.logger.scalar_summary(tag, value, iteration)
         # two stage training
         if iteration >= hps.patch_start_iter:
             for step in range(hps.n_patch_steps):
                 #===================== Train patch discriminator =====================#
                 data = next(self.data_loader)
                 (c_i, _), (x_i_t, _, _, _) = self.permute_data(data)
                 # encode
                 enc_i_t, = self.encode_step(x_i_t)
                 c_sample = self.sample_c(x_i_t.size(0))
                 x_tilde = self.decode_step(enc_i_t, c_i)
                 # Aux classify loss
                 patch_w_dis, real_logits, fake_logits, patch_gp = \
                         self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=True)
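                 # NOTE: c_loss, real_acc and fake_acc are assumed to be produced by
                 # auxiliary-classifier helpers (from real_logits / fake_logits)
                 # that are not included in this snippet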
                 patch_loss = -hps.beta_dis * patch_w_dis + hps.lambda_ * patch_gp + hps.beta_clf * c_loss
                 reset_grad([self.PatchDiscriminator])
                 patch_loss.backward()
                 grad_clip([self.PatchDiscriminator],
                           self.hps.max_grad_norm)
                 self.patch_opt.step()
                 # print info
                 info = {
                     f'{flag}/D_patch_w_dis': patch_w_dis.data[0],
                     f'{flag}/patch_gp': patch_gp.data[0],
                     f'{flag}/c_loss': c_loss.data[0],
                     f'{flag}/real_acc': real_acc,
                     f'{flag}/fake_acc': fake_acc,
                 }
                 slot_value = (step, iteration + 1, hps.iters) + \
                         tuple([value for value in info.values()])
                 log = 'patch_D-%d:[%06d/%06d], w_dis=%.3f, gp=%.2f, c_loss=%.3f, real_acc=%.2f, fake_acc=%.2f'
                 print(log % slot_value)
                 for tag, value in info.items():
                     self.logger.scalar_summary(tag, value, iteration)
         #===================== Train G =====================#
         data = next(self.data_loader)
         (c_i, c_j), (x_i_t, x_i_tk, x_i_prime,
                      x_j) = self.permute_data(data)
         # encode
         enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(
             x_i_t, x_i_tk, x_i_prime, x_j)
         # decode
         x_tilde = self.decode_step(enc_i_t, c_i)
         loss_rec = torch.mean(torch.abs(x_tilde - x_i_t))
         # latent discriminate
         loss_adv = self.latent_discriminate_step(enc_i_t,
                                                  enc_i_tk,
                                                  enc_i_prime,
                                                  enc_j,
                                                  is_dis=False)
         ae_loss = loss_rec + current_alpha * loss_adv
         reset_grad([self.Encoder, self.Decoder])
         retain_graph = True if hps.n_patch_steps > 0 else False
         ae_loss.backward(retain_graph=retain_graph)
         grad_clip([self.Encoder, self.Decoder], self.hps.max_grad_norm)
         self.ae_opt.step()
         info = {
             f'{flag}/loss_rec': loss_rec.data[0],
             f'{flag}/loss_adv': loss_adv.data[0],
             f'{flag}/alpha': current_alpha,
         }
         slot_value = (iteration + 1, hps.iters) + tuple(
             [value for value in info.values()])
         log = 'G:[%06d/%06d], loss_rec=%.2f, loss_adv=%.2f, alpha=%.2e'
         print(log % slot_value)
         for tag, value in info.items():
             self.logger.scalar_summary(tag, value, iteration + 1)
         # patch discriminate
         if hps.n_patch_steps > 0 and iteration >= hps.patch_start_iter:
             c_sample = self.sample_c(x_i_t.size(0))
             x_tilde = self.decode_step(enc_i_t, c_sample)
             patch_w_dis, real_logits, fake_logits = \
                     self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=False)
             patch_loss = hps.beta_dec * patch_w_dis + hps.beta_clf * c_loss
             reset_grad([self.Decoder])
             patch_loss.backward()
             grad_clip([self.Decoder], self.hps.max_grad_norm)
             self.decoder_opt.step()
             info = {
                 f'{flag}/G_patch_w_dis': patch_w_dis.data[0],
                 f'{flag}/c_loss': c_loss.data[0],
                 f'{flag}/real_acc': real_acc,
                 f'{flag}/fake_acc': fake_acc,
             }
             slot_value = (iteration + 1, hps.iters) + tuple(
                 [value for value in info.values()])
             log = 'G:[%06d/%06d]: patch_w_dis=%.2f, c_loss=%.2f, real_acc=%.2f, fake_acc=%.2f'
             print(log % slot_value)
             for tag, value in info.items():
                 self.logger.scalar_summary(tag, value, iteration + 1)
         if iteration % 1000 == 0 or iteration + 1 == hps.iters:
             self.save_model(model_path, iteration)
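
The gradient-penalty terms (gp, latent_gp, patch_gp) returned by the discriminate/patch steps in these trainers are computed elsewhere. A minimal PyTorch sketch of the standard WGAN-GP penalty they are assumed to follow (Gulrajani et al., 2017):

import torch
from torch import autograd

def gradient_penalty(discriminator, real, fake):
    # interpolate between real and fake samples and penalize deviation of the
    # discriminator's gradient norm from 1
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    inter = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_out = discriminator(inter)
    grads = autograd.grad(outputs=d_out.sum(), inputs=inter, create_graph=True)[0]
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()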