def test_compute_reward():
    """Tests compute_reward method"""
    env = Bit_Flipping_Environment(5)

    # When the achieved goal equals the desired goal, the environment pays
    # the goal-achievement reward regardless of the (ignored) info argument.
    matching_goals = [
        [0, 0, 0, 1, 0],
        [1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0],
    ]
    for goal in matching_goals:
        reward = env.compute_reward(np.array(goal), np.array(goal), None)
        assert reward == env.reward_for_achieving_goal

    # Any mismatch between achieved and desired goal yields the step penalty.
    mismatched_pairs = [
        ([1, 1, 1, 1, 1], [0, 0, 0, 1, 0]),
        ([1, 1, 1, 1, 1], [0, 0, 0, 0, 0]),
    ]
    for achieved, desired in mismatched_pairs:
        reward = env.compute_reward(np.array(achieved), np.array(desired), None)
        assert reward == env.step_reward_for_not_achieving_goal
# Example #2 (0)
from environments.Four_Rooms_Environment import Four_Rooms_Environment
from drl.agents.hierarchical_agents.SNN_HRL import SNN_HRL
from drl.agents.actor_critic_agents.TD3 import TD3
from drl.agents.Trainer import Trainer
from drl.utilities.data_structures.Config import Config
from drl.agents.DQN_agents.DQN import DQN
import numpy as np
import torch

# Fix every random source so the experiment is reproducible run-to-run.
# NOTE(review): `random` and `Bit_Flipping_Environment` are not imported in
# this chunk — presumably imported in a part of the file not shown here.
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

# Experiment configuration for training agents on the bit-flipping task.
config = Config()
config.seed = 1
config.environment = Bit_Flipping_Environment(4)  # 4-bit environment
config.num_episodes_to_run = 2000
config.file_to_save_data_results = None  # None -> results not written to disk
config.file_to_save_results_graph = None  # None -> no results graph saved
config.visualise_individual_results = False
config.visualise_overall_agent_results = False
config.randomise_random_seed = False  # keep the fixed seed set above
config.runs_per_agent = 1
config.use_GPU = False
config.hyperparameters = {
    "DQN_Agents": {
        "learning_rate": 0.005,
        "batch_size": 64,
        "buffer_size": 40000,
        "epsilon": 0.1,
        "epsilon_decay_rate_denominator": 200,
def test_environment_actions():
    """Tests environment is executing actions correctly"""
    env = Bit_Flipping_Environment(5)
    env.reset()
    env.state = [1, 0, 0, 1, 0, 1, 0, 0, 1, 0]

    # Apply each action in sequence and check the resulting state.
    action_and_expected_state = [
        (0, [0, 0, 0, 1, 0, 1, 0, 0, 1, 0]),  # flip bit 0 off
        (0, [1, 0, 0, 1, 0, 1, 0, 0, 1, 0]),  # flip bit 0 back on
        (3, [1, 0, 0, 0, 0, 1, 0, 0, 1, 0]),  # flip bit 3 off
        (6, [1, 0, 0, 0, 0, 1, 0, 0, 1, 0]),  # action 6 leaves the state unchanged
    ]
    for action, expected_state in action_and_expected_state:
        env.step(action)
        env.state = env.next_state
        assert env.state == expected_state
def test_environment_goal_achievement():
    """Tests environment is registering goal achievement properly"""
    env = Bit_Flipping_Environment(5)
    env.reset()
    # 10-element state for a 5-bit environment: the first 5 entries are the
    # current bits; the last 5 presumably encode the goal half of the
    # observation — TODO confirm against Bit_Flipping_Environment.
    env.state = [1, 0, 0, 1, 0, 0, 0, 0, 0, 0]
    env.desired_goal = [0, 0, 0, 0, 0]

    # Flip bit 0 off: achieved goal still differs from desired -> penalty -1.
    env.step(0)
    assert env.reward == -1
    env.state = env.next_state
    assert env.achieved_goal == [0, 0, 0, 1, 0]

    # Flip bit 2 on: still not matching the all-zeros goal.
    env.step(2)
    assert env.reward == -1
    env.state = env.next_state
    assert env.achieved_goal == [0, 0, 1, 1, 0]

    # Flip bit 2 back off: back to the previous achieved goal, still no match.
    env.step(2)
    assert env.reward == -1
    env.state = env.next_state
    assert env.achieved_goal == [0, 0, 0, 1, 0]

    # Flip bit 3 off: achieved goal now matches the desired goal; reward 5 is
    # presumably env.reward_for_achieving_goal — verify against the env class.
    env.step(3)
    assert env.reward == 5