Example #1
# Imports for this example; `replay` is a project-local replay-memory module.
import os

from absl import app
from absl import flags
import gym
import numpy as np

import replay

FLAGS = flags.FLAGS
# FLAGS.outdir is assumed to be defined as a flag elsewhere in the original
# project; it names the directory the pickle file is written to.


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  env = gym.make('Pendulum-v0')

  behavior_policy = lambda obs: env.action_space.sample()
  # null_policy = lambda obs: np.zeros(env.action_space.shape)

  num_episodes = 100
  episode_length = 200

  memory = replay.Memory()

  for _ in range(num_episodes):
    # collect a trajectory
    obs = env.reset()
    memory.log_init(obs)

    for _ in range(episode_length):
      act = behavior_policy(obs)
      next_obs, reward, term, _ = env.step(act)
      memory.log_experience(obs, act, reward, next_obs)
      if term:
        break
      obs = next_obs

  s = memory.serialize()

  # Save pickle file
  with open(os.path.join(FLAGS.outdir, 'pendulum.pickle'), 'wb') as f:
    f.write(s)

  # Sanity check serialization.
  m2 = replay.Memory()
  m2.unserialize(s)
  print(np.array_equal(m2.entered_states(), memory.entered_states()))
  print(np.array_equal(m2.exited_states(), memory.exited_states()))
  print(np.array_equal(m2.attempted_actions(), memory.attempted_actions()))
  print(np.array_equal(m2.observed_rewards(), memory.observed_rewards()))
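
The replay.Memory class used above is project-local and not shown on this page. For orientation only, the sketch below is a hypothetical stand-in that is merely consistent with the calls made here (log_init, log_experience, serialize/unserialize, and the accessor methods); it is an assumption for illustration, not the original implementation.

import pickle

import numpy as np


class SimpleMemory(object):
  """Hypothetical stand-in for replay.Memory, matching the usage above."""

  def __init__(self):
    self._initial = []   # initial observation of each episode
    self._entered = []   # states s_t seen before acting
    self._exited = []    # states s_{t+1} seen after acting
    self._actions = []   # actions a_t attempted
    self._rewards = []   # rewards r_t observed

  def log_init(self, obs):
    self._initial.append(obs)

  def log_experience(self, obs, act, reward, next_obs):
    self._entered.append(obs)
    self._actions.append(act)
    self._rewards.append(reward)
    self._exited.append(next_obs)

  def entered_states(self):
    return np.asarray(self._entered)

  def exited_states(self):
    return np.asarray(self._exited)

  def attempted_actions(self):
    return np.asarray(self._actions)

  def observed_rewards(self):
    return np.asarray(self._rewards)

  def serialize(self):
    # Byte string suitable for open(..., 'wb').write().
    return pickle.dumps(self.__dict__)

  def unserialize(self, s):
    self.__dict__.update(pickle.loads(s))
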
Example #2
# Imports for this example. The modules `replay`, `util`, and `summary`, and
# the helpers `most_recent_file` and `sample_episode`, come from imports in
# the original project that are not shown here. FLAGS.experiment_path and
# FLAGS.num_episodes are assumed to be defined as flags elsewhere.
import itertools
import pprint
import re

from absl import app
from absl import flags
from absl import logging
import gym
import numpy as np
import tensorflow as tf  # TF 1.x-style API; v2 behavior is enabled below.
import yaml

FLAGS = flags.FLAGS


def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()

    config_file = most_recent_file(FLAGS.experiment_path, r'config.yaml')
    assert config_file
    with open(config_file, 'r') as f:
        config = util.AttrDict(**yaml.safe_load(f.read()))
    logging.info('Config:\n%s', pprint.pformat(config))
    env = gym.make(config.env)
    cls = globals()[config.policy]
    policy = cls(config)
    # Initialize policy
    policy.argmax(np.expand_dims(env.reset(), 0))

    # Load checkpoint.
    # Assuming policy is a keras.Model instance.
    logging.info('policy variables: %s',
                 [v.name for v in policy.trainable_variables])
    ckpt = tf.train.Checkpoint(policy=policy)
    ckpt_file = most_recent_file(FLAGS.experiment_path, r'model.ckpt-[0-9]+')
    if ckpt_file:
        ckpt_file = re.findall('^(.*/model.ckpt-[0-9]+)', ckpt_file)[0]
        logging.info('Checkpoint file: %s', ckpt_file)
        ckpt.restore(ckpt_file).assert_consumed()
    else:
        raise RuntimeError('No checkpoint found')

    summary_writer = tf.summary.create_file_writer(FLAGS.experiment_path,
                                                   flush_millis=10000)

    logging.info('Starting Evaluation')
    it = (range(FLAGS.num_episodes)
          if FLAGS.num_episodes >= 0 else itertools.count())
    for ep in it:
        memory = replay.Memory()
        sample_episode(env, policy, memory, max_episode_length=200)
        logging.info(ep)
        with summary_writer.as_default(), summary.always_record_summaries():
            summary.scalar('return', memory.observed_rewards().sum(), step=ep)
            summary.scalar('length',
                           memory.observed_rewards().shape[-1],
                           step=ep)

    logging.info('DONE')
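
The helpers most_recent_file and sample_episode above are defined elsewhere in the original project. For orientation only, a most_recent_file with the behavior this script relies on (return the newest file under a directory whose name matches a regex, or None) might look like the following sketch; the real helper may differ.

import os
import re


def most_recent_file(directory, pattern):
    """Hypothetical helper: newest file under `directory` matching `pattern`."""
    matches = []
    for root, _, files in os.walk(directory):
        for name in files:
            if re.search(pattern, name):
                matches.append(os.path.join(root, name))
    if not matches:
        return None
    # Pick the most recently modified match.
    return max(matches, key=os.path.getmtime)
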
Example #3
File: DDPG.py Project: yusme/DDPG
    def __init__(self, env):

        self.env = env
        self.obs_dim = env.observation_space.shape[0]
        self.act_dim = env.action_space.shape[0]

        print "------env ", env.observation_space.shape[
            0], " ", self.act_dim, " obs ", env.action_space.shape[
                0], self.obs_dim
        action_bound = env.action_space.high

        self.avarage_reward = []
        self.loss = []

        self.noise_action = self.noise_act(self.act_dim)

        #self.add_noise=self.noise_action(self.act_dim);

        # ---------------------------
        # - Initialize replay memory
        # ---------------------------
        # Q-learning is an off-policy algorithm: it updates the Q-value
        # without making assumptions about the actual behavior policy.

        self.replay_buffer = replayMemory.Memory(MemorySize, Minibatch_size,
                                                 self.act_dim, self.obs_dim)

        # ---------------------------
        # - Initialize Q-learning with function approximation
        # ---------------------------

        # The critic produces a temporal-difference (TD) error and is updated
        # with gradients obtained from that TD error signal.

        self.critic = Critic.CriticNetwork(self.act_dim, self.obs_dim,
                                           Minibatch_size)
        self.actor_policy = Actor.PolicyNetwork(self.obs_dim, self.act_dim,
                                                Minibatch_size)

        # session = tf.Session()
        session = tf.InteractiveSession()
        # initialize_all_variables() is deprecated; use the global initializer.
        session.run(tf.global_variables_initializer())
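
The noise_act function referenced above is not part of this snippet. DDPG implementations commonly add Ornstein-Uhlenbeck noise to the actor's output for exploration; the sketch below shows one such noise process, given only as an assumption about what a helper like noise_act might provide, not the code from yusme/DDPG.

import numpy as np


class OUNoise(object):
    """Ornstein-Uhlenbeck process, a common exploration noise for DDPG."""

    def __init__(self, act_dim, mu=0.0, theta=0.15, sigma=0.2):
        self.mu = mu * np.ones(act_dim)
        self.theta = theta
        self.sigma = sigma
        self.state = np.copy(self.mu)

    def reset(self):
        # Start each episode from the mean.
        self.state = np.copy(self.mu)

    def sample(self):
        # x <- x + theta * (mu - x) + sigma * N(0, I)
        dx = self.theta * (self.mu - self.state)
        dx += self.sigma * np.random.randn(len(self.state))
        self.state = self.state + dx
        return self.state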