def collectData(info):
    # Unpack the job: number of transitions to collect, save location and worker ID.
    i, location, ID = info
    print('Start', ID)
    disablePrint()
    agent = Agent(memory=i)
    env = Environment(render=False).fruitbot
    while i > 0:
        obs = clean(env.reset())
        # Reset the LSTM hidden and cell states at the start of each episode.
        hn = torch.zeros(2, 1, hidden_size, device=device)
        cn = torch.zeros(2, 1, hidden_size, device=device)
        while i > 0:
            i -= 1
            # hn, cn = hn.detach(), cn.detach()
            act, obs_old, h0, c0, hn, cn = agent.choose(obs, hn, cn)
            obs, rew, done, _ = env.step(act)
            # Store the transition and detached recurrent states in the agent's memory.
            obs = agent.remember(obs_old.detach(), act,
                                 clean(obs).detach(), rew, h0.detach(),
                                 c0.detach(), hn.detach(), cn.detach(),
                                 int(not done))
            env.render()
            if done:
                break
        env.close()
    saveData(agent, location, ID)
    enablePrint()
    print('Done', ID)
    return os.getpid()
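collectData takes an (amount, location, ID) tuple and returns the worker's PID, which suggests it is meant to be fanned out to worker processes. A minimal driver sketch, assuming a multiprocessing.Pool; the memory budget, save directory and worker count are illustrative assumptions, not part of the original snippet:

# Sketch only: dispatch several collectData jobs in parallel.
# The memory budget (1000), the 'data/' directory and the worker count are assumptions.
from multiprocessing import Pool

if __name__ == '__main__':
    jobs = [(1000, 'data/', worker_id) for worker_id in range(4)]
    with Pool(processes=4) as pool:
        pids = pool.map(collectData, jobs)
    print('Workers finished, PIDs:', pids)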
Example #2
def collectData(agent):
    print('Start', agent.memory.size)
    disablePrint()
    i = agent.memory.size
    env = Environment(render=False).fruitbot
    while i > 0:
        obs = clean(env.reset())
        hn = torch.zeros(2, 1, hidden_size, device=device)
        cn = torch.zeros(2, 1, hidden_size, device=device)
        while i > 0:
            i -= 1
            # hn, cn = hn.detach(), cn.detach()
            act, obs_old, h0, c0, hn, cn = agent.choose(obs, hn, cn)
            obs, rew, done, _ = env.step(act)
            obs = agent.remember(obs_old.detach(), act,
                                 clean(obs).detach(), rew, h0.detach(),
                                 c0.detach(), hn.detach(), cn.detach(),
                                 int(not done))
            env.render()
            if done:
                break
        env.close()
    enablePrint()
    print('Done')
    return agent.memory.memory
Example #3
class TestEnvironment(unittest.TestCase):
    def setUp(self):
        # An observation space
        observation_space = gym.spaces.Discrete(7)

        # Default reward
        default_reward = Vector([1, 2, 1])

        # Set seed to 0 for testing.
        self.environment = Environment(observation_space=observation_space,
                                       default_reward=default_reward,
                                       seed=0)

    def tearDown(self):
        self.environment = None

    def test_init(self):
        """
        Testing if constructor works
        :return:
        """

        # All environments must have the following attributes.
        self.assertTrue(hasattr(self.environment, '_actions'))
        self.assertTrue(hasattr(self.environment, '_icons'))
        self.assertTrue(hasattr(self.environment, 'actions'))
        self.assertTrue(hasattr(self.environment, 'icons'))
        self.assertTrue(hasattr(self.environment, 'action_space'))
        self.assertTrue(hasattr(self.environment, 'observation_space'))
        self.assertTrue(hasattr(self.environment, 'np_random'))
        self.assertTrue(hasattr(self.environment, 'initial_seed'))
        self.assertTrue(hasattr(self.environment, 'initial_state'))
        self.assertTrue(hasattr(self.environment, 'current_state'))
        self.assertTrue(hasattr(self.environment, 'finals'))
        self.assertTrue(hasattr(self.environment, 'obstacles'))
        self.assertTrue(hasattr(self.environment, 'default_reward'))

        # All environments must have the following methods.
        self.assertTrue(hasattr(self.environment, 'step'))
        self.assertTrue(hasattr(self.environment, 'initial_seed'))
        self.assertTrue(hasattr(self.environment, 'reset'))
        self.assertTrue(hasattr(self.environment, 'render'))
        self.assertTrue(hasattr(self.environment, 'next_state'))
        self.assertTrue(hasattr(self.environment, 'is_final'))

        self.assertIsInstance(self.environment.observation_space,
                              gym.spaces.Space)
        self.assertIsInstance(self.environment.action_space, gym.spaces.Space)

        self.assertEqual(self.environment.initial_state,
                         self.environment.current_state)

    def test_icons(self):
        """
        Testing icons property
        :return:
        """
        self.assertEqual(self.environment._icons, self.environment.icons)

    def test_actions(self):
        """
        Testing actions property
        :return:
        """
        self.assertEqual(self.environment._actions, self.environment.actions)

    def test_action_space_length(self):
        pass

    def test_seed(self):
        """
        Testing the seed method
        :return:
        """

        self.environment.seed(seed=0)
        n1_1 = self.environment.np_random.randint(0, 10)
        n1_2 = self.environment.np_random.randint(0, 10)

        self.environment.seed(seed=0)
        n2_1 = self.environment.np_random.randint(0, 10)
        n2_2 = self.environment.np_random.randint(0, 10)

        self.assertEqual(n1_1, n2_1)
        self.assertEqual(n1_2, n2_2)

    def test_reset(self):
        """
        Testing reset method
        :return:
        """

        # Set the current state to a random state
        self.environment.current_state = self.environment.observation_space.sample()

        # Reset environment
        self.environment.reset()

        # Asserts
        self.assertEqual(self.environment.initial_state,
                         self.environment.current_state)

    def test_states(self):
        """
        Testing that all states are contained in the observation space
        :return:
        """
        pass

    def test_reachable_states(self):
        pass

    def test_transition_probability(self):
        pass

    def test_transition_reward(self):
        pass
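The test case above follows the standard unittest pattern, so it can be run directly; a minimal entry point, assuming the class lives in an executable module:

# Sketch: standard unittest entry point for running TestEnvironment directly.
if __name__ == '__main__':
    unittest.main()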
Example #4
File: graphs.py Project: Pozas91/tiadas
def test_agents(environment: Environment,
                hv_reference: Vector,
                variable: str,
                agents_configuration: dict,
                graph_configuration: dict,
                epsilon: float = None,
                alpha: float = None,
                max_steps: int = None,
                states_to_observe: list = None,
                number_of_agents: int = 30,
                gamma: float = 1.,
                solution: list = None,
                initial_q_value: Vector = None,
                evaluation_mechanism: EvaluationMechanism = None):
    """
    If DATA_PER_STATE is chosen in graph_configuration, the agent trains for `limit` steps and train_data is
    collected only at the last step (`interval` is ignored).

    If MEMORY is chosen in graph_configuration, the agent trains for `limit` steps and train_data is collected
    every `interval` steps.

    :param initial_q_value:
    :param graph_configuration:
    :param solution:
    :param environment:
    :param hv_reference:
    :param variable:
    :param agents_configuration:
    :param epsilon:
    :param alpha:
    :param max_steps:
    :param states_to_observe:
    :param number_of_agents:
    :param gamma:
    :param evaluation_mechanism:
    :return:
    """

    # Extract graph_types
    graph_types = set(graph_configuration.keys())

    if len(graph_types) > 2:
        print("Isn't recommended more than 2 graphs")

    # Parameters
    if states_to_observe is None:
        states_to_observe = {environment.initial_state}

    complex_states = isinstance(environment.observation_space[0],
                                gym.spaces.Tuple)

    if complex_states and GraphType.DATA_PER_STATE in graph_types:
        print(
            "This environment has complex states, so DATA_PER_STATE graph is disabled."
        )
        graph_configuration.pop(GraphType.DATA_PER_STATE)

    # Build environment
    env_name = environment.__class__.__name__
    env_name_snake = str_to_snake_case(env_name)

    # File timestamp
    timestamp = int(time.time())

    # Write all information in configuration path
    write_config_file(timestamp=timestamp,
                      number_of_agents=number_of_agents,
                      env_name_snake=env_name_snake,
                      seed=','.join(map(str, range(number_of_agents))),
                      epsilon=epsilon,
                      alpha=alpha,
                      gamma=gamma,
                      max_steps=max_steps,
                      variable=variable,
                      agents_configuration=agents_configuration,
                      graph_configuration=graph_configuration,
                      evaluation_mechanism=evaluation_mechanism)

    # Create graphs structure
    graphs, graphs_info = initialize_graph_data(
        graph_types=graph_types, agents_configuration=agents_configuration)

    # Show information
    print('Environment: {}'.format(env_name))

    for graph_type in graph_types:

        # Extract interval and limit
        interval = graph_configuration[graph_type].get('interval', 1)
        limit = graph_configuration[graph_type]['limit']

        # Show information
        print(('\t' * 1) +
              "Graph type: {} - [{}/{}]".format(graph_type, limit, interval))

        # Set interval to get train_data
        Agent.interval_to_get_data = interval

        # Execute an iteration with a different initial seed for each agent
        for seed in range(number_of_agents):

            # Show information
            print(('\t' * 2) + "Execution: {}".format(seed + 1))

            # For each agent type
            for agent_type in agents_configuration:

                # Show information
                print(('\t' * 3) + 'Agent: {}'.format(agent_type.value))

                # Extract configuration for that agent
                for configuration in agents_configuration[agent_type].keys():

                    # Show information
                    print(
                        ('\t' * 4) + '{}: {}'.format(variable, configuration),
                        end=' ')

                    # Record the start time
                    t0 = time.time()

                    # Reset environment
                    environment.reset()
                    environment.seed(seed=seed)

                    # Variable parameters
                    parameters = {
                        'epsilon': epsilon,
                        'alpha': alpha,
                        'gamma': gamma,
                        'max_steps': max_steps,
                        'evaluation_mechanism': evaluation_mechanism,
                        'initial_value': initial_q_value
                    }

                    if variable == 'decimal_precision':
                        Vector.set_decimal_precision(
                            decimal_precision=configuration)
                    else:
                        # Modify current configuration
                        parameters.update({variable: configuration})

                    agent, v_s_0 = train_agent_and_get_v_s_0(
                        agent_type=agent_type,
                        environment=environment,
                        graph_type=graph_type,
                        graph_types=graph_types,
                        hv_reference=hv_reference,
                        limit=limit,
                        seed=seed,
                        parameters=parameters,
                        states_to_observe=states_to_observe)

                    print('-> {:.2f}s'.format(time.time() - t0))

                    train_data = dict()

                    if agent_type is AgentType.PQL and graph_type is GraphType.DATA_PER_STATE:
                        train_data.update({
                            'vectors': {
                                state: {
                                    action: agent.q_set(state=state,
                                                        action=action)
                                    for action in agent.nd[state].keys()
                                }
                                for state in agent.nd.keys()
                            }
                        })

                    # Order vectors by closeness to the origin Vec(0)
                    train_data.update({
                        'v_s_0':
                        Vector.order_vectors_by_origin_nearest(vectors=v_s_0),
                        # 'q': agent.q,
                        # 'v': agent.v
                    })

                    # Write vectors found into path
                    dumps_train_data(
                        timestamp=timestamp,
                        seed=seed,
                        env_name_snake=env_name_snake,
                        train_data=train_data,
                        variable=variable,
                        agent_type=agent_type,
                        configuration=configuration,
                        evaluation_mechanism=evaluation_mechanism,
                        columns=environment.observation_space[0].n)

                    # Update graphs
                    update_graphs(graphs=graphs,
                                  agent=agent,
                                  graph_type=graph_type,
                                  configuration=str(configuration),
                                  agent_type=agent_type,
                                  states_to_observe=states_to_observe,
                                  graphs_info=graphs_info,
                                  solution=solution)

    prepare_data_and_show_graph(timestamp=timestamp,
                                env_name=env_name,
                                env_name_snake=env_name_snake,
                                graphs=graphs,
                                number_of_agents=number_of_agents,
                                agents_configuration=agents_configuration,
                                alpha=alpha,
                                epsilon=epsilon,
                                gamma=gamma,
                                graph_configuration=graph_configuration,
                                max_steps=max_steps,
                                initial_state=environment.initial_state,
                                variable=variable,
                                graphs_info=graphs_info,
                                evaluation_mechanism=evaluation_mechanism,
                                solution=solution)
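The graph_configuration and agents_configuration arguments are read as nested dicts inside test_agents: each GraphType maps to a dict with a mandatory 'limit' and an optional 'interval', and each AgentType maps to a dict whose keys are the values tried for `variable`. A hedged sketch of a call follows; the environment, the agent type, the vectors and all numeric values are illustrative assumptions rather than settings from the project:

# Sketch only: the dict shapes follow from how test_agents reads its arguments above,
# but the environment, agent type, vectors and numbers are assumed for illustration.
graph_configuration = {
    GraphType.MEMORY: {'limit': 3000, 'interval': 100},  # sample train_data every 100 steps
    GraphType.DATA_PER_STATE: {'limit': 3000},           # 'interval' is ignored for this type
}

agents_configuration = {
    AgentType.PQL: {0.7: 'red', 0.8: 'blue'},  # values tried for `variable`, mapped to plot labels
}

test_agents(environment=DeepSeaTreasure(),  # hypothetical environment instance
            hv_reference=Vector([-25, 0]),
            variable='alpha',
            agents_configuration=agents_configuration,
            graph_configuration=graph_configuration,
            epsilon=0.7,
            alpha=0.8,
            max_steps=250,
            number_of_agents=5,
            gamma=1.)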