Ejemplo n.º 1
0
    def init_plot(self, env, policy):
        """Hand the environment and policy over to the plotting worker.

        Does nothing when ``Plotter.enable`` is false. If no worker
        process/queue exists yet, ``init_worker()`` is called first. On
        macOS ('Darwin') one animated rollout is run in this process before
        an ``Op.UPDATE`` message carrying ``(env, policy)`` is queued.

        Args:
            env: Environment to render.
            policy: Policy used to act in ``env``.
        """
        if not Plotter.enable:
            return

        worker_ready = self._process and self._queue
        if not worker_ready:
            self.init_worker()

        # Needed in order to draw the glfw window on the main thread.
        running_on_macos = 'Darwin' in platform.platform()
        if running_on_macos:
            rollout(env,
                    policy,
                    max_path_length=np.inf,
                    animated=True,
                    speedup=5)

        update = Message(op=Op.UPDATE, args=(env, policy), kwargs=None)
        self._queue.put(update)
Ejemplo n.º 2
0
    def _obtain_evaluation_samples(self,
                                   env,
                                   num_trajs=100,
                                   max_path_length=1000):
        """Sample evaluation trajectories with the current policy.

        Runs ``num_trajs`` independent rollouts of ``self.policy`` in
        ``env`` and bundles them into a single batch.

        Args:
            env (metarl.envs.MetaRLEnv): The environment used to obtain
                trajectories.
            num_trajs (int): Number of trajectories to sample.
            max_path_length (int): Maximum number of steps per trajectory
                (forwarded to each ``rollout`` call).

        Returns:
            TrajectoryBatch: Evaluation trajectories, representing the best
                current performance of the algorithm.

        """
        # deterministic=True is forwarded to rollout(); presumably it makes
        # the policy act deterministically instead of sampling -- confirm
        # against rollout()'s implementation.
        paths = [
            rollout(env,
                    self.policy,
                    max_path_length=max_path_length,
                    deterministic=True) for _ in range(num_trajs)
        ]
        return TrajectoryBatch.from_trajectory_list(self.env_spec, paths)
Ejemplo n.º 3
0
    def _worker_start(self):
        """Run the plotting worker's message loop until STOP is received.

        Each iteration drains every message currently in ``self._queue``,
        keeping only the most recent message of each ``op`` type, then:

        * ``Op.STOP``   -- leave the loop.
        * ``Op.UPDATE`` -- replace the env and policy being rendered.
        * ``Op.DEMO``   -- load new parameter values into the policy, record
          the max path length and render an animated rollout. This is
          handled even if an UPDATE arrived in the same batch.
        * neither UPDATE nor DEMO -- re-render an animated rollout with the
          current env/policy once ``max_length`` has been set by a DEMO.

        Until the first DEMO has been processed the worker blocks on the
        queue instead of spinning. A KeyboardInterrupt ends the loop quietly.
        """
        import queue  # get_nowait() signals "nothing there" via queue.Empty

        env = None
        policy = None
        max_length = None
        initial_rollout = True
        try:
            while True:
                msgs = {}
                if initial_rollout:
                    # Nothing to animate yet: block (yielding the processor)
                    # until at least one message arrives.
                    msg = self._queue.get()
                    msgs[msg.op] = msg
                # Drain the rest of the queue, keeping only the last message
                # of each type. empty() is only a hint, so get_nowait() may
                # still find the queue empty; stop draining in that case
                # instead of letting queue.Empty kill the worker.
                while not self._queue.empty():
                    try:
                        msg = self._queue.get_nowait()
                    except queue.Empty:
                        break
                    msgs[msg.op] = msg

                if Op.STOP in msgs:
                    break
                if Op.UPDATE in msgs:
                    env, policy = msgs[Op.UPDATE].args
                if Op.DEMO in msgs:
                    # Deliberately not an elif: a DEMO queued together with
                    # an UPDATE must not be dropped.
                    param_values, max_length = msgs[Op.DEMO].args
                    policy.set_param_values(param_values)
                    initial_rollout = False
                    rollout(env,
                            policy,
                            max_path_length=max_length,
                            animated=True,
                            speedup=5)
                elif Op.UPDATE not in msgs and max_length:
                    # No new instructions: keep animating with the current
                    # settings.
                    rollout(env,
                            policy,
                            max_path_length=max_length,
                            animated=True,
                            speedup=5)
        except KeyboardInterrupt:
            pass
 def test_max_path_length(self):
     """rollout() output is truncated to exactly max_path_length steps."""
     # pylint: disable=unsubscriptable-object
     expected_len = 3
     path = utils.rollout(self.env,
                          self.policy,
                          max_path_length=expected_len)
     for key in ('observations', 'actions', 'rewards'):
         assert path[key].shape[0] == expected_len
     dist_keys = self.policy.distribution.dist_info_keys
     agent_info = [path['agent_infos'][k] for k in dist_keys]
     assert agent_info[0].shape[0] == expected_len
     # 'dummy' is the env_info key
     assert path['env_infos']['dummy'].shape[0] == expected_len
Ejemplo n.º 5
0
def _worker_collect_one_path(g, max_path_length, scope=None):
    """Collect one rollout using the env and policy of the scoped state.

    Args:
        g: Worker state object; resolved to its scoped view via
            ``_get_scoped_g(g, scope)``.
        max_path_length (int): Maximum length passed to ``rollout``.
        scope: Optional scope name used to resolve ``g``.

    Returns:
        tuple: ``(path, n_steps)`` where ``n_steps`` is
            ``len(path['rewards'])``.
    """
    scoped = _get_scoped_g(g, scope)
    path = rollout(scoped.env, scoped.policy, max_path_length=max_path_length)
    n_steps = len(path['rewards'])
    return path, n_steps
 def test_does_flatten(self):
     """Observations come back with shape (16,), actions with (2, 2)."""
     path = utils.rollout(self.env, self.policy, max_path_length=5)
     first_obs_shape = path['observations'][0].shape
     assert first_obs_shape == (16, )
     first_action_shape = path['actions'][0].shape
     assert first_action_shape == (2, 2)
Ejemplo n.º 7
0
                             "(or 'y' or 'n').\n")


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=str, help='path to the snapshot file')
    parser.add_argument('--max_path_length',
                        type=int,
                        default=1000,
                        help='Max length of rollout')
    parser.add_argument('--speedup', type=float, default=1, help='Speedup')
    args = parser.parse_args()

    # Snapshots that use TensorFlow must be loaded and run inside a TF
    # session, so everything below happens within one.
    with tf.compat.v1.Session():
        # SECURITY: joblib.load unpickles the file, which can execute
        # arbitrary code. Only replay snapshots from a trusted source.
        data = joblib.load(args.file)
        policy = data['algo'].policy
        env = data['env']
        # Replay animated rollouts until the user declines to continue.
        while True:
            rollout(env,
                    policy,
                    max_path_length=args.max_path_length,
                    animated=True,
                    speedup=args.speedup)
            if not query_yes_no('Continue simulation?'):
                break