Example No. 1
def test_integration_with_keras():
    TestTransition = collections.namedtuple('TestTransition', ['observation'])

    # Just a smoke test to check that nothing errors out.
    n_transitions = 10
    obs_shape = (4, )
    network_sig = data.NetworkSignature(
        input=data.TensorSignature(shape=obs_shape),
        output=data.TensorSignature(shape=(1, )),
    )
    trainer = supervised.SupervisedTrainer(
        network_signature=network_sig,
        target=supervised.target_solved,
        batch_size=2,
        n_steps_per_epoch=3,
        replay_buffer_capacity=n_transitions,
    )
    trainer.add_episode(
        data.Episode(
            transition_batch=TestTransition(
                observation=np.zeros((n_transitions, ) + obs_shape), ),
            return_=123,
            solved=False,
        ))
    network = keras.KerasNetwork(network_signature=network_sig)
    trainer.train_epoch(network)
Example No. 2
def construct_episodes(actions, rewards, **kwargs):
    """Constructs episodes from actions and rewards nested lists.

    Args:
        actions (list): Actions for each episode, for example:
        [
            [a00, a01, a02, ...], # Actions in the first episode.
            [a10, a11, a12, ...], # Actions in the second episode.
            ...
        ]
        rewards (list): Rewards for each episode, for example:
        [
            [r00, r01, r02, ...], # Rewards in the first episode.
            [r10, r11, r12, ...], # Rewards in the second episode.
            ...
        ]
        **kwargs (dict): Keyword arguments passed to Episode.

    Returns:
        list of Episodes where:
         - Transition observations and next observations are set to None.
         - Done flag is True only for the last transition in the episode.
         - Episode.return_ is calculated as an undiscounted sum of rewards.
    """
    episodes = []
    for acts, rews in zip(actions, rewards):
        transitions = [
            data.Transition(None, act, rew, False, None, {}, {})
            for act, rew in zip(acts[:-1], rews[:-1])
        ]
        transitions.append(
            data.Transition(None, acts[-1], rews[-1], True, None, {}, {}))
        transition_batch = data.nested_stack(transitions)
        episodes.append(data.Episode(transition_batch, sum(rews), **kwargs))
    return episodes
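A minimal usage sketch of construct_episodes (the action and reward values below are illustrative only); the nested lists follow the layout described in the docstring, and extra keyword arguments such as solved are forwarded to data.Episode:

episodes = construct_episodes(
    actions=[[0, 1, 0], [1, 1]],            # two episodes, 3 and 2 steps long
    rewards=[[0.0, 0.0, 1.0], [0.0, 1.0]],
    solved=True,                            # forwarded to data.Episode
)
assert episodes[0].return_ == 1.0           # undiscounted sum of rewards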
Example No. 3
def construct_episodes(actions, rewards):
    """Constructs episodes from actions and rewards nested lists."""
    episodes = []
    for acts, rews in zip(actions, rewards):
        transitions = [
            # TODO(koz4k): Initialize using kwargs.
            data.Transition(None, act, rew, False, None, {})
            for act, rew in zip(acts[:-1], rews[:-1])
        ]
        transitions.append(
            data.Transition(None, acts[-1], rews[-1], True, None, {}))
        transition_batch = data.nested_stack(transitions)
        episodes.append(data.Episode(transition_batch, sum(rews)))
    return episodes
Example No. 4
def test_multiple_targets():
    TestTransition = collections.namedtuple('TestTransition',
                                            ['observation', 'agent_info'])

    network_sig = data.NetworkSignature(
        input=data.TensorSignature(shape=(1, )),
        # Two outputs.
        output=(
            data.TensorSignature(shape=(1, )),
            data.TensorSignature(shape=(2, )),
        ),
    )
    trainer = supervised.SupervisedTrainer(
        network_signature=network_sig,
        # Two targets.
        target=(supervised.target_solved, supervised.target_qualities),
        batch_size=1,
        n_steps_per_epoch=1,
        replay_buffer_capacity=1,
    )
    trainer.add_episode(
        data.Episode(
            transition_batch=TestTransition(
                observation=np.zeros((1, 1)),
                agent_info={'qualities': np.zeros((1, 2))},
            ),
            return_=123,
            solved=False,
        ))

    class TestNetwork(core.DummyNetwork):
        """Mock class."""
        def train(self,
                  data_stream,
                  n_steps,
                  epoch,
                  validation_data_stream=None):
            np.testing.assert_equal(
                list(data_stream()),
                [
                    testing.zero_pytree(
                        (network_sig.input, network_sig.output),
                        shape_prefix=(1, ))
                ],
            )

            return {}

    trainer.train_epoch(TestNetwork(network_sig))
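For reference, a rough sketch of the single batch the mock network asserts on above, assuming testing.zero_pytree builds zero arrays matching each TensorSignature with shape_prefix prepended (the shapes below follow that assumption):

import numpy as np

expected_batch = (
    np.zeros((1, 1)),                      # input: prefix (1,) + shape (1,)
    (np.zeros((1, 1)), np.zeros((1, 2))),  # the two outputs / targets
)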
Example No. 5
def test_integration_with_keras():
    # Just a smoke test to check that nothing errors out.
    n_transitions = 10
    obs_shape = (4, )
    trainer = supervised.SupervisedTrainer(
        input_shape=obs_shape,
        target_fn=supervised.target_solved,
        batch_size=2,
        n_steps_per_epoch=3,
        replay_buffer_capacity=n_transitions,
    )
    trainer.add_episode(
        data.Episode(
            transition_batch=_TestTransition(
                observation=np.zeros((n_transitions, ) + obs_shape), ),
            return_=123,
            solved=False,
        ))
    network = keras.KerasNetwork(input_shape=obs_shape)
    trainer.train_epoch(network)
Example No. 6
    def solve(self, env, epoch=None, init_state=None, time_limit=None):
        """Solves a given environment using OnlineAgent.act().

        Args:
            env (gym.Env): Environment to solve.
            epoch (int): Current training epoch or None if no training.
            init_state (object): Reset the environment to this state.
                If None, then do normal gym.Env.reset().
            time_limit (int or None): Maximum number of steps to make on the
                solved environment. None means no time limit.

        Yields:
            Network-dependent: A stream of Network inputs requested for
            inference.

        Returns:
            data.Episode: Episode object containing a batch of collected
            transitions and the return for the episode.
        """
        yield from super().solve(env, epoch, init_state, time_limit)

        self._epoch = epoch

        model_env = env

        if time_limit is not None:
            # Add the TimeLimitWrapper _after_ passing the model env to the
            # agent, so the states cloned/restored by the agent do not contain
            # the number of steps made so far - this would break state lookup
            # in some Agents.
            env = envs.TimeLimitWrapper(env, time_limit)

        if init_state is None:
            # Model-free case...
            full_observation = env.reset()
            observation = np.concatenate(
                [full_observation['observation'],
                 full_observation['desired_goal']],
                axis=-1)
        else:
            # Model-based case. Assumed here that restore_state() returns the
            # same dict observation structure as reset().
            full_observation = env.restore_state(init_state)
            observation = np.concatenate(
                [full_observation['observation'],
                 full_observation['desired_goal']],
                axis=-1)

        yield from self.reset(model_env, observation)

        transitions = []
        done = False
        info = {}
        # Track distinct observations seen, used for the 'move_diversity' metric.
        places = {tuple(observation.flatten())}
        while not done:
            # Forward network prediction requests to BatchStepper.
            (action, agent_info) = yield from self.act(observation)
            (full_next_observation, reward, done, info) = env.step(action)
            next_observation = np.concatenate(
                [full_next_observation['observation'],
                 full_next_observation['desired_goal']],
                axis=-1)
            places.add(tuple(next_observation.flatten()))

            transitions.append(
                data.Transition(
                    observation=full_observation,
                    action=action,
                    reward=reward,
                    done=done,
                    next_observation=full_next_observation,
                    agent_info=agent_info,
                ))
            full_observation = full_next_observation
            observation = next_observation

        return_ = sum(transition.reward for transition in transitions)
        transitions = self.postprocess_transitions(transitions)

        solved = info['solved'] if 'solved' in info else None
        truncated = (info['TimeLimit.truncated']
                     if 'TimeLimit.truncated' in info else None)
        transition_batch = data.nested_stack(transitions)

        info = {'move_diversity': len(places)}
        return data.Episode(
            transition_batch=transition_batch,
            return_=return_,
            solved=solved,
            truncated=truncated,
            info=info,
        )
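The dict-observation handling above can be illustrated in isolation; the values below are made up and assume a goal-conditioned environment whose observations carry 'observation' and 'desired_goal' arrays:

import numpy as np

full_observation = {
    'observation': np.array([0.0, 1.0]),
    'desired_goal': np.array([2.0, 3.0]),
}
observation = np.concatenate(
    [full_observation['observation'], full_observation['desired_goal']],
    axis=-1)
# observation -> array([0., 1., 2., 3.])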
Example No. 7
    def solve(self, env, epoch=None, init_state=None, time_limit=None):
        """Solves a given environment using OnlineAgent.act().

        Args:
            env (gym.Env): Environment to solve.
            epoch (int): Current training epoch or None if no training.
            init_state (object): Reset the environment to this state.
                If None, then do normal gym.Env.reset().
            time_limit (int or None): Maximum number of steps to make on the
                solved environment. None means no time limit.

        Yields:
            Network-dependent: A stream of Network inputs requested for
            inference.

        Returns:
            data.Episode: Episode object containing a batch of collected
            transitions and the return for the episode.
        """
        yield from super().solve(env, epoch, init_state, time_limit)

        self._epoch = epoch

        model_env = env

        if time_limit is not None:
            # Add the TimeLimitWrapper _after_ passing the model env to the
            # agent, so the states cloned/restored by the agent do not contain
            # the number of steps made so far - this would break state lookup
            # in some Agents.
            env = envs.TimeLimitWrapper(env, time_limit)

        if init_state is None:
            # Model-free case...
            observation = env.reset()
        else:
            # Model-based case...
            observation = env.restore_state(init_state)

        yield from self.reset(model_env, observation)

        for callback in self._callbacks:
            callback.on_episode_begin(env, observation, epoch)

        transitions = []
        done = False
        info = {}
        while not done:
            # Forward network prediction requests to BatchStepper.
            (action, agent_info) = yield from self.act(observation)
            (next_observation, reward, done, info) = env.step(action)

            for callback in self._callbacks:
                callback.on_real_step(agent_info, action, next_observation,
                                      reward, done)

            transitions.append(
                data.Transition(
                    observation=observation,
                    action=action,
                    reward=reward,
                    done=done,
                    next_observation=next_observation,
                    agent_info=agent_info,
                ))
            observation = next_observation

        for callback in self._callbacks:
            callback.on_episode_end()

        transitions = self.postprocess_transitions(transitions)

        return_ = sum(transition.reward for transition in transitions)
        solved = info['solved'] if 'solved' in info else None
        truncated = (info['TimeLimit.truncated']
                     if 'TimeLimit.truncated' in info else None)
        transition_batch = data.nested_stack(transitions)
        additional_info = (info['additional_info']
                           if 'additional_info' in info else None)
        return data.Episode(transition_batch=transition_batch,
                            return_=return_,
                            solved=solved,
                            truncated=truncated,
                            additional_info=additional_info)
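Since solve() is a generator coroutine that yields network inputs and returns the finished data.Episode, a caller drives it roughly as in the sketch below; agent and network are hypothetical placeholders, and network.predict stands in for whatever inference call the surrounding code actually uses.

episode_coroutine = agent.solve(env, epoch=0, time_limit=100)
try:
    request = next(episode_coroutine)  # first network input requested
    while True:
        prediction = network.predict(request)
        request = episode_coroutine.send(prediction)
except StopIteration as e:
    episode = e.value  # the data.Episode returned by solve()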
Example No. 8
    def solve(self, env, epoch=None, init_state=None, time_limit=None):
        """Solves a given environment using OnlineAgent.act()."""
        yield from super().solve(env, epoch, init_state, time_limit)

        self._epoch = epoch

        model_env = env

        if time_limit is not None:
            env = envs.TimeLimitWrapper(env, time_limit)

        if init_state is None:
            # Model-free case...
            observation = env.reset()
        else:
            # Model-based case...
            observation = env.restore_state(init_state)

        yield from self.reset(model_env, observation)

        for callback in self._callbacks:
            callback.on_episode_begin(env, observation, epoch)

        transitions = []
        done = False
        info = {}
        while not done:
            # Forward network prediction requests to BatchStepper.
            (action, agent_info) = yield from self.act(observation)
            (next_observation, reward, done, info) = env.step(action)

            for callback in self._callbacks:
                callback.on_real_step(agent_info, action, next_observation,
                                      reward, done)

            transitions.append(
                data.Transition(
                    observation=observation,
                    action=action,
                    reward=reward,
                    done=done,
                    next_observation=next_observation,
                    agent_info=agent_info,
                ))
            observation = next_observation

        for callback in self._callbacks:
            callback.on_episode_end()

        transitions = self.postprocess_transitions(transitions)

        return_ = sum(transition.reward for transition in transitions)
        solved = info['solved'] if 'solved' in info else None
        truncated = (info['TimeLimit.truncated']
                     if 'TimeLimit.truncated' in info else None)
        transition_batch = data.nested_stack(transitions)
        action_space_size = space.max_size(model_env.action_space)
        return data.Episode(transition_batch=transition_batch,
                            return_=return_,
                            solved=solved,
                            truncated=truncated,
                            action_space_size=action_space_size)