def test_y_range():
    """Checks that every output of an RNN built with ``y_range`` lies strictly inside that range.

    Repeated for 100 randomly drawn (lower, upper) pairs. The count checks expect
    exactly 25 * 10 values (a batch of 10 sequences, 25 units in the last layer),
    so they also implicitly verify the size of the output.
    """
    expected_count = 25 * 10
    for _ in range(100):
        first = random.random() - 3.0 * random.random()
        second = random.random() + 2.0 * random.random()
        lower_bound, upper_bound = sorted((first, second))

        network = RNN(
            layers_info=[["lstm", 20], ["gru", 5], ["lstm", 25]],
            input_dim=22,
            hidden_activations="relu",
            initialiser="xavier",
            y_range=(lower_bound, upper_bound),
        )
        batch = torch.randn((10, 11, 22))
        flat = network.forward(batch).reshape(1, -1).squeeze()

        above_lower = (flat > lower_bound).sum().item()
        below_upper = (flat < upper_bound).sum().item()
        assert above_lower == expected_count, "lower {} vs. {} ".format(lower_bound, flat)
        assert below_upper == expected_count, "upper {} vs. {} ".format(upper_bound, flat)
def test_check_input_data_into_forward_once():
    """Tests that check_input_data_into_forward_once method only runs once

    The network is built with ``input_dim=5``. Before any successful forward pass,
    a mismatched input must be rejected by the input check (``AssertionError``).
    After one successful pass the check is expected to be skipped, so the same
    mismatched input should instead fail inside the layers (``RuntimeError``).
    """
    rnn = RNN(layers_info=[["lstm", 20], ["gru", 5], ["lstm", 25]],
              hidden_activations="relu", input_dim=5,
              output_activation="relu", initialiser="xavier")

    data_not_to_throw_error = torch.randn((1, 4, 5))
    data_to_throw_error = torch.randn((1, 2, 20))

    # Its last dimension (20) does not match input_dim=5, so the checked first
    # call must reject it.
    with pytest.raises(AssertionError):
        rnn.forward(data_to_throw_error)

    # Run the valid pass outside pytest.raises: if it raised a RuntimeError itself
    # it would satisfy the expectation below and mask a real failure.
    rnn.forward(data_not_to_throw_error)

    with pytest.raises(RuntimeError):
        rnn.forward(data_to_throw_error)
def test_output_activation_return_return_final_seq_only_off():
    """Tests whether network outputs data that has gone through correct activation function

    Every network here is built with ``return_final_seq_only=False``, so per-step
    sums are taken over ``dim=2`` of the output. Checks cover:

    - relu outputs are non-negative;
    - sigmoid outputs lie in [0, 1] without summing to 1;
    - softmax outputs lie in [0, 1] and sum to 1;
    - with no output activation, outputs take both signs and do not all sum to 1.
    """
    RANDOM_ITERATIONS = 20
    input_dim = 100

    def build_rnn(layers_info, **extra_kwargs):
        # Options shared by every network in this test; only the layers and the
        # output-specific options (output_activation, batch_norm) vary.
        return RNN(layers_info=layers_info, hidden_activations="relu", input_dim=input_dim,
                   return_final_seq_only=False, initialiser="xavier", **extra_kwargs)

    def sums_to_one(out):
        # Compare against 1 with an explicit absolute tolerance. Rounding must not be
        # used here: torch.round((x * 10 ** 5) / (10 ** 5)) rounds to the nearest
        # integer and would accept any sum in [0.5, 1.5).
        summed_result = torch.sum(out, dim=2)
        return torch.allclose(summed_result, torch.ones_like(summed_result),
                              rtol=0.0, atol=1e-5)

    recurrent_then_linear = [["lstm", 20], ["gru", 5], ["linear", 10], ["linear", 3]]
    recurrent_only = [["lstm", 20], ["gru", 5], ["lstm", 25]]

    for _ in range(RANDOM_ITERATIONS):
        data = torch.randn((25, 10, input_dim))

        # relu output activation -> every value is non-negative.
        rnn_instance = build_rnn(recurrent_then_linear, output_activation="relu", batch_norm=True)
        out = rnn_instance.forward(data)
        assert torch.all(out >= 0)

        rnn_instance = build_rnn([["lstm", 20], ["gru", 5]], output_activation="relu")
        out = rnn_instance.forward(data)
        assert torch.all(out >= 0)

        rnn_instance = build_rnn(recurrent_then_linear, output_activation="relu")
        out = rnn_instance.forward(data)
        assert torch.all(out >= 0)

        # sigmoid -> values in [0, 1], but (unlike softmax) they do not sum to 1.
        rnn_instance = build_rnn(recurrent_then_linear, output_activation="sigmoid")
        out = rnn_instance.forward(data)
        assert torch.all(out >= 0)
        assert torch.all(out <= 1)
        assert torch.all(torch.sum(out, dim=2) != 1.0)

        # softmax -> values in [0, 1] that sum to 1 at every step.
        rnn_instance = build_rnn(recurrent_then_linear, output_activation="softmax")
        out = rnn_instance.forward(data)
        assert torch.all(out >= 0)
        assert torch.all(out <= 1)
        assert sums_to_one(out)

        rnn_instance = build_rnn(recurrent_only, output_activation="softmax")
        out = rnn_instance.forward(data)
        assert torch.all(out >= 0)
        assert torch.all(out <= 1)
        assert sums_to_one(out)

        # No output activation -> both signs appear and the sums are not all 1.
        rnn_instance = build_rnn(recurrent_only)
        out = rnn_instance.forward(data)
        assert not torch.all(out >= 0)
        assert not torch.all(out <= 0)
        assert not sums_to_one(out)

        rnn_instance = build_rnn(recurrent_only + [["linear", 8]])
        out = rnn_instance.forward(data)
        assert not torch.all(out >= 0)
        assert not torch.all(out <= 0)
        assert not sums_to_one(out)