Example #1
def test_discriminator():
    # NumPy and TensorFlow (1.x API) imports are added here so the snippet is
    # self-contained; `config` and `Discriminator` are assumed to come from
    # sibling modules of this project.
    import numpy as np
    import tensorflow as tf
    from logger import MyLogger

    mylogger = MyLogger('./log')

    configs = config.configs['gail']
    print(tf.get_default_session())
    with tf.Session() as sess:
        d = Discriminator(arch_params=configs.discriminator_params,
                          stddev=0.02)
        s1 = np.array([[1, 2, 3, 4], [1, 2, 3, 4]]).reshape([-1, 4])
        a1 = np.array([[1, -1], [1, -1]]).reshape([-1, 2])

        s2 = np.array([8, -6, 5, 7]).reshape([-1, 4])
        a2 = np.array([1, 1]).reshape([-1, 2])

        is_training = tf.placeholder(tf.bool)
        e_s = tf.placeholder(dtype=tf.float32,
                             shape=list(s1.shape),
                             name='e_s')
        e_a = tf.placeholder(dtype=tf.float32,
                             shape=list(a1.shape),
                             name='e_a')
        g_s = tf.placeholder(dtype=tf.float32,
                             shape=list(s2.shape),
                             name='g_s')
        g_a = tf.placeholder(dtype=tf.float32,
                             shape=list(a2.shape),
                             name='g_a')
        # print([None]+list(s1.shape))

        e_output = d(state=e_s, action=e_a, is_training=is_training)
        g_output = d(state=g_s,
                     action=g_a,
                     is_training=is_training,
                     reuse=True)
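        # The second call passes reuse=True, so both branches share the same
        # discriminator weights; only the inputs (expert vs. generator batch) differ.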

        discriminator_loss = -tf.reduce_mean(
            tf.log(e_output + configs.epsilon) +
            tf.log(1 - g_output + configs.epsilon))
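        # Standard GAN/GAIL discriminator objective: maximize
        # E_expert[log D(s, a)] + E_gen[log(1 - D(s, a))], written here as a
        # minimized negative mean; `epsilon` keeps the logs numerically stable.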

        # # tf.GraphKeys.UPDATE_OPS ,tf.GraphKeys.TRAINABLE_VARIABLES
        # with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        #     discriminator_train_step = tf.train.AdamOptimizer(configs.learning_rate, configs.beta1,
        #                                                       configs.beta2).minimize(discriminator_loss)

        discriminator_train_step = tf.train.AdamOptimizer(
            configs.learning_rate, configs.beta1, configs.beta2).minimize(
                discriminator_loss,
                var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           scope='discriminator'))
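        # var_list limits the Adam update to variables created under the
        # 'discriminator' scope, leaving the rest of the graph untouched.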

        tf.global_variables_initializer().run(session=sess)

        for i in range(100):
            eoutput, goutput, _, dloss = sess.run([
                e_output, g_output, discriminator_train_step,
                discriminator_loss
            ],
                                                  feed_dict={
                                                      e_s: s1,
                                                      e_a: a1,
                                                      g_s: s2,
                                                      g_a: a2,
                                                      is_training: False
                                                  })
            # No-op updates; the same inputs are fed on every iteration.
            s1 += 0
            a1 += 0
            mylogger.write_summary_scalar(iteration=i,
                                          tag='dloss',
                                          value=dloss)
            mylogger.write_summary_scalar(iteration=i, tag='loss', value=dloss)
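Both examples depend on a small `MyLogger` wrapper from `logger.py` that is not reproduced on this page. A minimal sketch of such a helper, built on TF 1.x summary writers, is shown below; the class name and method signatures are taken from how the examples call it, and everything else is an assumption rather than the project's actual code.

import tensorflow as tf


class MyLogger:
    def __init__(self, log_dir):
        # One FileWriter per run; event files land under log_dir for TensorBoard.
        self.writer = tf.summary.FileWriter(log_dir)

    def add_sess_graph(self, graph):
        # Attach the computation graph so it appears in TensorBoard's Graphs tab.
        self.writer.add_graph(graph)

    def write_summary_scalar(self, iteration, tag, value):
        # Log a single scalar under `tag` at step `iteration`.
        summary = tf.Summary(
            value=[tf.Summary.Value(tag=tag, simple_value=float(value))])
        self.writer.add_summary(summary, iteration)
        self.writer.flush()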
Example #2
def learn(*, policy, env, buffer_size, nminibatches, total_timesteps, ent_coef,
          lr, vf_coef, max_grad_norm, gamma, lam, log_interval, noptepochs,
          cliprange, save_interval, bc, bc_steps, algo, log_dir, save_path):

    if isinstance(lr, float):
        lr = constfn(lr)
    else:
        assert callable(lr)  # callable() checks that the object can be called
    if isinstance(cliprange, float):
        cliprange = constfn(cliprange)
    else:
        assert callable(cliprange)

    total_timesteps = int(total_timesteps)
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch_train = buffer_size // nminibatches
    mylogger = MyLogger(log_dir=log_dir)
    if algo == 'gail':
        sampler = Sampler(buffer_size=buffer_size)

    sess = tf.Session()
    make_model = lambda: Model(sess=sess,
                               policy=policy,
                               ob_space=ob_space,
                               ac_space=ac_space,
                               nbatch_act=1,
                               nbatch_train=nbatch_train,
                               ent_coef=ent_coef,
                               vf_coef=vf_coef,
                               max_grad_norm=max_grad_norm)

    model = make_model()
    runner = Runner(sess=sess,
                    env=env,
                    model=model,
                    buffer_size=buffer_size,
                    gamma=gamma,
                    lam=lam,
                    algo=algo)
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(max_to_keep=10)
    load(sess=sess, saver=saver, save_path=save_path)
    mylogger.add_sess_graph(sess.graph)
    epinfobuf = deque(maxlen=100)

    nupdates = total_timesteps // buffer_size  # total number of update rounds = total sampling steps // buffer_size

    buffer_name = [
        'obs', 'returns', 'masks', 'actions', 'values', 'neglogpacs', 'states',
        'epinfos', 'ep_r', 'ep_count'
    ]
    buffer = {name: [] for name in buffer_name}

    # behavior cloning
    if bc:
        for i in range(1, bc_steps + 1):
            print("bc iter: ", i)
            mse = 0
            for epoch in range(100):
                mb_obs, mb_act = sampler.bc_sample(batch_size=512)
                mse += model.behavior_clone(obs=mb_obs, actions=mb_act)
            print("mse", mse)
            mylogger.write_summary_scalar(iteration=i, tag="bc_mse", value=mse)
            if i % 20 == 0:
                save(sess=sess,
                     saver=saver,
                     save_path=save_path + "bc",
                     global_step=i)

    # training
    for update in range(1, nupdates + 1):
        frac = 1.0 - (update - 1.0) / nupdates
        lrnow = lr(frac)
        cliprangenow = cliprange(frac)
        assert buffer_size % nminibatches == 0

        inds = np.arange(buffer_size)
        np.random.shuffle(inds)
        states = runner.states
        mblossvals = []

        # Set the generator/discriminator update ratio: train the discriminator
        # harder (20 passes) during the first 50 updates and on every 20th update.
        critic_g = 2
        critic_d = 5
        if update % 20 == 0 or update < 50:
            critic_g = 2
            critic_d = 20

        if states is None:
            # Non-recurrent (MLP) policy.
            for i in range(critic_g):
                inds = np.arange(buffer_size)
                # Collect one buffer of rollout data.
                bf_temp = obs, returns, masks, actions, values, neglogpacs, states, epinfos, ep_r, ep_count = runner.run()
                for key_idx in range(len(buffer_name)):
                    # Keep the rollout in a temporary buffer for the Discriminator update.
                    buffer[buffer_name[key_idx]].append(bf_temp[key_idx])
                mylogger.write_summary_scalar(update, 'epr_sum', ep_r)
                mylogger.write_summary_scalar(update, 'number of episodes',
                                              ep_count)
                mylogger.write_summary_scalar(update, 'epr_mean',
                                              ep_r / ep_count)
                epinfobuf.extend(epinfos)
                for _ in range(noptepochs):
                    # Run noptepochs optimization passes over one buffer of data.
                    np.random.shuffle(inds)
                    for start in range(0, buffer_size, nbatch_train):
                        # Take a random minibatch slice from the shuffled buffer.
                        end = start + nbatch_train
                        mbinds = inds[start:end]
                        slices = (arr[mbinds]
                                  for arr in (obs, returns, masks, actions,
                                              values, neglogpacs))
                        mblossvals.append(
                            model.train(lrnow, cliprangenow, *slices))
        else:
            # Recurrent (RNN) policy.
            nenvs = 100
            for i in range(4):
                bf_temp = obs, returns, masks, actions, values, neglogpacs, states, epinfos, ep_r, ep_count = runner.run()
                for key_idx in range(len(buffer_name)):
                    buffer[buffer_name[key_idx]].append(bf_temp[key_idx])
                assert nenvs % nminibatches == 0
                envsperbatch = nenvs // nminibatches
                envinds = np.arange(nenvs)
                flatinds = np.arange(buffer_size).reshape(nenvs, -1)
                for _ in range(noptepochs):
                    np.random.shuffle(envinds)
                    for start in range(0, nenvs, envsperbatch):
                        end = start + envsperbatch
                        mbenvinds = envinds[start:end]
                        mbflatinds = flatinds[mbenvinds].ravel()
                        slices = (arr[mbflatinds]
                                  for arr in (obs, returns, masks, actions,
                                              values, neglogpacs))
                        mbstates = np.array([model.initial_state] *
                                            envsperbatch).reshape(
                                                [envsperbatch, -1])
                        mblossvals.append(
                            model.train(lrnow, cliprangenow, *slices,
                                        mbstates))

        if algo == 'ppo':
            pass
        else:
            # GAIL: train the Discriminator on expert vs. generated samples.
            for i in range(critic_d):
                expert_s, expert_a = sampler.random_sample()
                gen_s, returns, masks, gen_a, values, neglogpacs, states, epinfos, ep_r, ep_count \
                    = sample_from_buffer(buffer=buffer)
                for _ in range(5):
                    runner.discriminator.train(sess, expert_s, expert_a, gen_s,
                                               gen_a)

                expert_rewards = runner.discriminator.get_rewards_e(
                    sess, expert_s, expert_a)
                gen_rewards = runner.discriminator.get_rewards(
                    sess, gen_s, gen_a)
                gan_loss = runner.discriminator.get_ganLoss(
                    sess, expert_s, expert_a, gen_s, gen_a)
                mylogger.write_summary_scalar(update, 'expert_reward mean',
                                              np.mean(expert_rewards))
                mylogger.write_summary_scalar(update, 'gen_rewards mean',
                                              np.mean(gen_rewards))
                mylogger.write_summary_scalar(update, 'discriminator loss',
                                              np.mean(gan_loss))

        buffer_clear(buffer)
        lossvals = np.mean(mblossvals, axis=0)
        mblossvals.clear()
        # lossvals: pg_loss, vf_loss, entropy, surrogate loss
        mylogger.write_summary_scalar(update, "pg_loss", lossvals[0])
        mylogger.write_summary_scalar(update, "vf_loss", lossvals[1])
        mylogger.write_summary_scalar(update, "entropy", lossvals[2])
        mylogger.write_summary_scalar(update, "surrogate loss", lossvals[3])
        mylogger.write_summary_scalar(update, 'critic_d', critic_d)
        mylogger.write_summary_scalar(update, 'critic_g', critic_g)
        if save_interval and (update % save_interval == 0 or update == 1):
            save(global_step=update + bc_steps,
                 saver=saver,
                 sess=sess,
                 save_path=save_path)
    sess.close()
    env.close()
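Example #2 also leans on several helpers from the surrounding project (Model, Runner, Sampler, load, save, sample_from_buffer, constfn, buffer_clear) that are not shown on this page. Two of them are simple enough to sketch here; treat these as assumptions about the missing code rather than the project's actual definitions.

def constfn(val):
    # Baselines-style schedule wrapper: turn a constant hyperparameter into a
    # callable of the remaining-progress fraction, so constant and annealed
    # values share the same call signature (lr(frac), cliprange(frac)).
    def f(_):
        return val
    return f


def buffer_clear(buffer):
    # Empty every list in the temporary rollout buffer while keeping its keys,
    # so the next update starts from a clean slate.
    for key in buffer:
        buffer[key].clear()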