def test_initiation():
    """Tests whether DQN_HER initiates correctly"""
    config.hyperparameters["batch_size"] = 64
    agent = DQN_HER(config)
    agent.reset_game()

    assert agent.ordinary_buffer_batch_size == int(0.2 * 64)
    assert agent.HER_buffer_batch_size == 64 - int(0.2 * 64)

    assert agent.q_network_local.input_dim == 8
    assert agent.q_network_local.output_layers[0].out_features == 4

    assert isinstance(agent.state_dict, dict)

    assert agent.observation.shape[0] == 4
    assert agent.desired_goal.shape[0] == 4
    assert agent.achieved_goal.shape[0] == 4

    assert agent.state.shape[0] == 8
    assert not agent.done
    assert agent.next_state is None
    assert agent.reward is None

    config.hyperparameters["batch_size"] = 3
def test_tracks_changes_from_multiple_actions():
    """Tests that it tracks the changes as a result of actions correctly"""

    agent = DQN_HER(config)
    agent.reset_game()

    for ix in range(4):
        previous_obs = agent.observation
        previous_desired_goal = agent.desired_goal
        previous_achieved_goal = agent.achieved_goal

        agent.action = ix
        agent.conduct_action_in_changeable_goal_envs(agent.action)

        assert agent.next_state.shape[0] == 8
        assert isinstance(agent.next_state_dict, dict)
        assert not all(agent.observation == previous_obs)
        assert not all(agent.achieved_goal == previous_achieved_goal)
        assert all(agent.desired_goal == previous_desired_goal)

        agent.track_changeable_goal_episodes_data()
        agent.save_experience()
        if agent.done: agent.save_alternative_experience()

        agent.state_dict = agent.next_state_dict  # this is to set the state for the next iteration
        agent.state = agent.next_state

    states, actions, rewards, next_states, dones = agent.HER_memory.sample(4)

    assert all(states[1] == torch.Tensor([1.0, 1., 1., 1., 0., 0., 0., 0.]))
    assert all(actions == torch.Tensor([[1.], [0.], [3.], [2.]]))
    assert all(rewards == torch.Tensor([[-1.], [-1.], [4.], [-1.]]))
    assert all(dones == torch.Tensor([[0.], [0.], [1.], [0.]]))
        "alpha_prioritised_replay": 0.6,
        "beta_prioritised_replay": 0.4,
        "incremental_td_error": 1e-8,
        "update_every_n_steps": 3,
        "linear_hidden_units": [20, 20, 20],
        "final_layer_activation": "None",
        "batch_norm": False,
        "gradient_clipping_norm": 5,
        "HER_sample_proportion": 0.8,
        "clip_rewards": False
    }
}

trainer = Trainer(config, [DQN_HER])
config.hyperparameters = config.hyperparameters["DQN_Agents"]
agent = DQN_HER(config)
agent.reset_game()


def test_initiation():
    """Tests whether DQN_HER initiates correctly"""
    config.hyperparameters["batch_size"] = 64
    agent = DQN_HER(config)
    agent.reset_game()

    assert agent.ordinary_buffer_batch_size == int(0.2 * 64)
    assert agent.HER_buffer_batch_size == 64 - int(0.2 * 64)

    assert agent.q_network_local.input_dim == 8
    assert agent.q_network_local.output_layers[0].out_features == 4