Beispiel #1
0
def test_phenotype_get_prediction():
    """A phenotype maps one 4-value input row to a (1, 3) ndarray prediction."""
    n_inputs, hidden_sizes, n_outputs = 4, [6, 8], 3
    lower_bound, upper_bound = -1, 1

    weight_count = get_number_of_nn_weights(n_inputs, hidden_sizes, n_outputs)
    genotype = SimpleGenotype.get_random_genotype(weight_count, lower_bound,
                                                  upper_bound)
    phenotype = Phenotype.get_phenotype_from_genotype(genotype, n_inputs,
                                                      hidden_sizes, n_outputs)

    # A single sample (batch of one) with one value per input node.
    single_row = array([[1, 0, 0, -1]])
    result = Phenotype.get_prediction(phenotype, single_row)

    assert type(result) is ndarray
    assert result.shape == (1, n_outputs)
Beispiel #2
0
def test_phenotype_get_phenotype_for_genotype():
    """Decoding a genotype yields a Phenotype that keeps its topology and
    stores every layer as an ndarray."""
    n_inputs, hidden_sizes, n_outputs = 4, [6, 8], 3
    lower_bound, upper_bound = -1, 1

    weight_count = get_number_of_nn_weights(n_inputs, hidden_sizes, n_outputs)
    genotype = SimpleGenotype.get_random_genotype(weight_count, lower_bound,
                                                  upper_bound)
    decoded = Phenotype.get_phenotype_from_genotype(genotype, n_inputs,
                                                    hidden_sizes, n_outputs)

    assert type(decoded) is Phenotype
    # Presumably one weight matrix per connection: 4->6, 6->8, 8->3.
    assert len(decoded.layers) == 3
    assert decoded.input_nodes == n_inputs
    assert decoded.hidden_layers_nodes == hidden_sizes
    assert decoded.output_nodes == n_outputs
    assert all(type(layer) is ndarray for layer in decoded.layers)
Beispiel #3
0
def test_get_full_game_representation_strategy():
    """The full representation has one entry per board cell plus 4 extra
    values, with the expected initial cells set to 1."""
    def full_representation(game):
        return Game.get_full_game_representation_strategy(game)

    board_width = board_height = 8
    initial_length = 3
    input_nodes = board_width * board_height + 4
    hidden_layer_nodes = [64]
    output_nodes = 3
    lower_bound, upper_bound = -1, 1

    weight_count = get_number_of_nn_weights(input_nodes, hidden_layer_nodes,
                                            output_nodes)
    genotype = SimpleGenotype.get_random_genotype(weight_count, lower_bound,
                                                  upper_bound)
    phenotype = Phenotype(genotype.weights, input_nodes, hidden_layer_nodes,
                          output_nodes)
    game = Game(board_width, board_height, phenotype, 777,
                full_representation, initial_length)

    representation = Game.get_full_game_representation_strategy(game)

    assert len(representation) == input_nodes
    # Presumably the cells covered by the initial 3-block snake.
    for occupied_index in (36, 44, 52):
        assert representation[occupied_index] == 1
Beispiel #4
0
def test_simple_genotype_get_random_genotypes():
    """get_random_genotypes returns exactly the requested number of
    SimpleGenotype instances."""
    genotype_count = 20
    weights_per_genotype = 20
    lower_bound, upper_bound = -1, 1

    genotypes = SimpleGenotype.get_random_genotypes(
        genotype_count, weights_per_genotype, lower_bound, upper_bound)

    assert len(genotypes) == genotype_count
    assert all(type(genotype) is SimpleGenotype for genotype in genotypes)
Beispiel #5
0
def test_simple_genotype_get_mutated_genotype():
    """Mutating a SimpleGenotype yields a new genotype of the same length in
    which every weight has changed."""
    number_of_nn_weights = 20
    weight_lower_threshold = -1
    weight_upper_threshold = 1
    # These are the values actually passed to the mutation. Previously they
    # were declared here (with a different deviation) but bypassed by
    # literals in the call.
    mutation_mean = 0
    mutation_standard_deviation = 0.3

    random_simple_genotype = SimpleGenotype.get_random_genotype(
        number_of_nn_weights, weight_lower_threshold, weight_upper_threshold)
    mutated_simple_genotype = SimpleGenotype.get_mutated_genotype(
        random_simple_genotype, mutation_mean, mutation_standard_deviation)

    # NOTE(review): unless SimpleGenotype defines __eq__, this only checks
    # that a distinct object was returned; the per-weight loop below is the
    # substantive check.
    assert random_simple_genotype != mutated_simple_genotype
    assert len(mutated_simple_genotype.weights) == number_of_nn_weights

    for original_weight, mutated_weight in zip(
            random_simple_genotype.weights, mutated_simple_genotype.weights):
        assert original_weight != mutated_weight
Beispiel #6
0
def test_simple_genotype_get_random_genotype():
    """A random SimpleGenotype has the requested number of weights, each
    within the inclusive [lower, upper] bounds."""
    weight_count = 20
    lower_bound, upper_bound = -1, 1

    genotype = SimpleGenotype.get_random_genotype(weight_count, lower_bound,
                                                  upper_bound)

    assert type(genotype) is SimpleGenotype
    assert len(genotype.weights) == weight_count
    assert all(lower_bound <= weight <= upper_bound
               for weight in genotype.weights)
Beispiel #7
0
        def mutation_strategy(genotype):
            """Mutate *genotype* with the operator for ``arguments.genotype``.

            Returns the result of the matching ``get_mutated_genotype`` call,
            or ``None`` when the configured choice matches no known genotype.
            """
            selected = arguments.genotype
            if selected == simple_genotype_choice:
                return SimpleGenotype.get_mutated_genotype(
                    genotype, arguments.mutation_mean,
                    arguments.mutation_standard_deviation)
            elif selected == uncorrelated_one_step_size_genotype_choice:
                return UncorrelatedOneStepSizeGenotype.get_mutated_genotype(
                    genotype, arguments.tau1)
            elif selected == uncorrelated_n_step_size_genotype_choice:
                return UncorrelatedNStepSizeGenotype.get_mutated_genotype(
                    genotype, arguments.tau1, arguments.tau2)
            # No configured genotype matched.
            return None
Beispiel #8
0
def test_game_max_points_threshold():
    """A game ends when its score is at the max points threshold and,
    independently, when it is at the min points threshold."""
    def game_representation_strategy(game):
        return Game.get_full_game_representation_strategy(game)

    snake_length = 5
    width = 32
    height = 18
    snack_eaten_points = 2

    input_nodes = width * height + 4
    hidden_layer_nodes = [64]
    output_nodes = 3
    number_of_nn_weights = get_number_of_nn_weights(input_nodes,
                                                    hidden_layer_nodes,
                                                    output_nodes)
    weight_lower_threshold = -1
    weight_upper_threshold = 1
    max_points_threshold = 200
    min_points_threshold = -10

    sample_genotype = SimpleGenotype.get_random_genotype(
        number_of_nn_weights, weight_lower_threshold, weight_upper_threshold)
    sample_phenotype = Phenotype(sample_genotype.weights, input_nodes,
                                 hidden_layer_nodes, output_nodes)

    def new_game():
        """Create a fresh, still-running game with the thresholds above."""
        return Game(width,
                    height,
                    sample_phenotype,
                    777,
                    game_representation_strategy,
                    snake_length,
                    snack_eaten_points,
                    max_points_threshold=max_points_threshold,
                    min_points_threshold=min_points_threshold)

    max_game = new_game()
    max_game.score = max_points_threshold
    assert Game.get_next_game(max_game).status == GameStatus.ENDED

    # Use a fresh game: re-using the already ENDED game from the max check
    # would make this assertion pass regardless of the min threshold.
    min_game = new_game()
    min_game.score = min_points_threshold
    assert Game.get_next_game(min_game).status == GameStatus.ENDED
Beispiel #9
0
def test_game_initialization():
    """A new game places the snake deterministically and puts the snack on a
    free, in-bounds cell that is reproducible from the seed."""
    def game_representation_strategy(game):
        return Game.get_full_game_representation_strategy(game)

    snake_length = 5
    width = 32
    height = 18
    seed = 777

    input_nodes = width * height + 4
    hidden_layer_nodes = [64]
    output_nodes = 3
    number_of_nn_weights = get_number_of_nn_weights(input_nodes,
                                                    hidden_layer_nodes,
                                                    output_nodes)
    weight_lower_threshold = -1
    weight_upper_threshold = 1

    sample_genotype = SimpleGenotype.get_random_genotype(
        number_of_nn_weights, weight_lower_threshold, weight_upper_threshold)
    sample_phenotype = Phenotype(sample_genotype.weights, input_nodes,
                                 hidden_layer_nodes, output_nodes)

    sample_game = Game(width, height, sample_phenotype, seed,
                       game_representation_strategy, snake_length)

    assert len(sample_game.snake) == snake_length

    correct_snake_blocks = [(16, 9), (17, 9), (18, 9), (19, 9), (20, 9)]
    assert sample_game.snake == correct_snake_blocks
    assert sample_game.snack_perspective == (21, 9)

    assert sample_game.snack not in correct_snake_blocks
    # Board coordinates are 0-based (width * height cells in the full
    # representation), so x == width or y == height is already off the board.
    assert 0 <= sample_game.snack[0] < width
    assert 0 <= sample_game.snack[1] < height

    # The same seed must reproduce the same snack position.
    another_game = Game(width, height, sample_phenotype, seed,
                        game_representation_strategy, snake_length)
    assert sample_game.snack == another_game.snack
Beispiel #10
0
def test_snake_points_assignment():
    """Eating a snack awards the configured points, grows the snake by one
    block and respawns the snack elsewhere."""
    def representation_strategy(game):
        return Game.get_full_game_representation_strategy(game)

    initial_length = 5
    board_width, board_height = 32, 18
    points_per_snack = 2

    input_nodes = board_width * board_height + 4
    hidden_layer_nodes = [64]
    output_nodes = 3
    lower_bound, upper_bound = -1, 1

    weight_count = get_number_of_nn_weights(input_nodes, hidden_layer_nodes,
                                            output_nodes)
    genotype = SimpleGenotype.get_random_genotype(weight_count, lower_bound,
                                                  upper_bound)
    phenotype = Phenotype(genotype.weights, input_nodes, hidden_layer_nodes,
                          output_nodes)
    game = Game(board_width, board_height, phenotype, 777,
                representation_strategy, initial_length, points_per_snack)

    # Force the snack onto the snake's head. The initial snake is
    # [(16, 9), (17, 9), (18, 9), (19, 9), (20, 9)].
    eaten_snack = (16, 9)
    game.snack = eaten_snack

    after_step = Game.get_next_game(game)

    assert after_step.score == points_per_snack
    assert after_step.snack != eaten_snack
    assert after_step.snack is not None
    assert len(after_step.snake) == initial_length + 1
    assert after_step.snack_perspective is not None
Beispiel #11
0
def get_evolution_summary(arguments,
                          input_nodes,
                          output_nodes,
                          memory_consumption_probe=100):
    """Build a human-readable, multi-line summary of an evolution run.

    The summary has four sections:
    - the basic configuration;
    - the genotype-specific mutation parameters;
    - calculated figures: the number of network weights, an estimate of the
      population's memory footprint and the approximate end time;
    - the enabled utilities: epoch summaries, population backups and
      initial population loading.

    Args:
        arguments: Parsed command-line arguments describing the run
            (presumably an ``argparse.Namespace`` -- TODO confirm).
        input_nodes: Number of neural network input nodes.
        output_nodes: Number of neural network output nodes.
        memory_consumption_probe: Number of random demo genotypes generated
            to estimate the mean size of one genotype.

    Returns:
        The summary lines joined with newlines.
    """
    output = [
        str.format('# Basic summary'),
        str.format('Input nodes: {}', input_nodes),
        str.format('Output nodes: {}', output_nodes),
        str.format('Genotype: {}', genotype_lookup[arguments.genotype]),
        str.format('Hidden layer nodes: {}', arguments.hidden_layer_nodes),
        str.format('Weight lower threshold: {}',
                   arguments.weight_lower_threshold),
        str.format('Weight upper threshold: {}',
                   arguments.weight_upper_threshold),
        str.format('Population size: {}', arguments.population_size),
        str.format('Tournament size: {}', arguments.tournament_size),
        str.format('Duration (hours): {}', arguments.duration),
        str.format('Use bias: {}', arguments.bias),
        str.format('# Genotype specific summary')
    ]

    if arguments.genotype == simple_genotype_choice:
        output.append(str.format('Mutation mean: {}', arguments.mutation_mean))
        output.append(
            str.format('Mutation standard deviation: {}',
                       arguments.mutation_standard_deviation))

    if arguments.genotype in (uncorrelated_one_step_size_genotype_choice,
                              uncorrelated_n_step_size_genotype_choice):
        output.append(
            str.format('Mutation step size lower threshold: {}',
                       arguments.mutation_step_size_lower_threshold))
        output.append(
            str.format('Mutation step size upper threshold: {}',
                       arguments.mutation_step_size_upper_threshold))
        output.append(str.format('Tau 1: {}', arguments.tau1))

        if arguments.genotype == uncorrelated_n_step_size_genotype_choice:
            output.append(str.format('Tau 2: {}', arguments.tau2))

    output.append(str.format('# Calculated summary'))
    # Pass ``arguments.bias`` so bias weights are counted when enabled.
    # Otherwise the reported weight count and the memory estimate below would
    # not match the network that is actually built and trained.
    number_of_nn_weights = get_number_of_nn_weights(
        input_nodes, arguments.hidden_layer_nodes, output_nodes,
        arguments.bias)
    output.append(
        str.format('Number of neural network weights: {}',
                   number_of_nn_weights))

    # Estimate memory consumption by averaging the size of a few random demo
    # genotypes and scaling by the population size.
    # NOTE(review): sys.getsizeof is shallow. For a Python list of floats it
    # counts the list object only, not the float objects it references.
    probe = range(memory_consumption_probe)

    if arguments.genotype == simple_genotype_choice:
        demo_genotypes = (SimpleGenotype.get_random_genotype(
            number_of_nn_weights, arguments.weight_lower_threshold,
            arguments.weight_upper_threshold) for _ in probe)
        demo_genotype_sizes = [
            getsizeof(demo_genotype.weights)
            for demo_genotype in demo_genotypes
        ]
    elif arguments.genotype == uncorrelated_one_step_size_genotype_choice:
        demo_genotypes = (UncorrelatedOneStepSizeGenotype.get_random_genotype(
            number_of_nn_weights, arguments.weight_lower_threshold,
            arguments.weight_upper_threshold,
            arguments.mutation_step_size_lower_threshold,
            arguments.mutation_step_size_upper_threshold) for _ in probe)
        # A single step size is wrapped in a list so it is sized the same way
        # as the per-weight step sizes of the n-step genotype.
        demo_genotype_sizes = [
            getsizeof(demo_genotype.weights) +
            getsizeof([demo_genotype.mutation_step_size])
            for demo_genotype in demo_genotypes
        ]
    elif arguments.genotype == uncorrelated_n_step_size_genotype_choice:
        demo_genotypes = (UncorrelatedNStepSizeGenotype.get_random_genotype(
            number_of_nn_weights, arguments.weight_lower_threshold,
            arguments.weight_upper_threshold,
            arguments.mutation_step_size_lower_threshold,
            arguments.mutation_step_size_upper_threshold) for _ in probe)
        demo_genotype_sizes = [
            getsizeof(demo_genotype.weights) +
            getsizeof(demo_genotype.mutation_step_sizes)
            for demo_genotype in demo_genotypes
        ]

    mean_demo_genotype_size = mean(
        demo_genotype_sizes) * arguments.population_size
    output.append(
        str.format('Calculated memory consumption (Python list): {}',
                   naturalsize(mean_demo_genotype_size)))
    output.append(
        str.format('Approximate end time: {}',
                   get_end_datetime(arguments.duration).isoformat(sep=' ')))

    output.append(str.format('# Utils summary'))
    output.append(str.format('Epoch summary: {}', arguments.epoch_summary))
    if arguments.epoch_summary:
        output.append(
            str.format('Epoch summary features: {}',
                       arguments.epoch_summary_features))
        output.append(
            str.format('Epoch summary interval: {}',
                       arguments.epoch_summary_interval))

    output.append(
        str.format('Population backup summary: {}',
                   arguments.population_backup))
    if arguments.population_backup:
        output.append(
            str.format('Population backup directory: {}',
                       arguments.population_backup_directory))
        output.append(
            str.format('Population backup interval: {}',
                       arguments.population_backup_interval))
        output.append(
            str.format('Population backup file extension: .{}',
                       arguments.population_backup_file_extension))

    if arguments.initial_population_directory:
        output.append(
            str.format('Initial population directory: {}',
                       arguments.initial_population_directory))
        output.append(
            str.format('Initial population file extension: {}',
                       arguments.initial_population_file_extension))

    return '\n'.join(output)
Beispiel #12
0
def handle_evolution_run(input_nodes, output_nodes, evaluation_strategy):
    """Parse the command line, print the run summary and evolve a population.

    Args:
        input_nodes: Number of neural network input nodes.
        output_nodes: Number of neural network output nodes.
        evaluation_strategy: Forwarded unchanged to ``get_evolved_population``.

    Returns:
        The result of ``get_evolved_population``, or ``None`` when
        ``arguments.dry_run`` is set (only the summary is printed).

    Raises:
        ValueError: If ``arguments.genotype`` is not a known genotype choice
            and no initial population directory was given.
    """
    argument_parser = get_core_argument_parser()
    arguments = argument_parser.parse_args()
    evolution_summary = get_evolution_summary(arguments, input_nodes,
                                              output_nodes)

    print(evolution_summary)
    if arguments.dry_run:
        # Dry run: report the configuration only, do not evolve anything.
        return None

    def phenotype_strategy(genotype):
        return Phenotype.get_phenotype_from_genotype(
            genotype, input_nodes, arguments.hidden_layer_nodes,
            output_nodes, arguments.bias)

    def parent_selection_strategy(phenotype_values):
        # Tournaments of size two use a dedicated selection function.
        if arguments.tournament_size == 2:
            return get_two_size_tournament_parent_selection(
                phenotype_values, arguments.population_size)
        return get_n_size_tournament_parent_selection(
            phenotype_values, arguments.tournament_size,
            arguments.population_size)

    def mutation_strategy(genotype):
        if arguments.genotype == simple_genotype_choice:
            return SimpleGenotype.get_mutated_genotype(
                genotype, arguments.mutation_mean,
                arguments.mutation_standard_deviation)

        if arguments.genotype == uncorrelated_one_step_size_genotype_choice:
            return UncorrelatedOneStepSizeGenotype.get_mutated_genotype(
                genotype, arguments.tau1)

        if arguments.genotype == uncorrelated_n_step_size_genotype_choice:
            return UncorrelatedNStepSizeGenotype.get_mutated_genotype(
                genotype, arguments.tau1, arguments.tau2)

    def offspring_selection_strategy(parents, mutated_parents):
        return get_age_based_offspring_selection(parents, mutated_parents)

    if arguments.initial_population_directory:
        # Resume from a saved population. A random population is only built
        # in the other branch, because here it would be discarded immediately.
        initial_population = handle_backup_load(
            arguments.initial_population_directory,
            arguments.initial_population_file_extension)
    else:
        number_of_nn_weights = get_number_of_nn_weights(
            input_nodes, arguments.hidden_layer_nodes, output_nodes,
            arguments.bias)
        if arguments.genotype == simple_genotype_choice:
            initial_population = SimpleGenotype.get_random_genotypes(
                arguments.population_size, number_of_nn_weights,
                arguments.weight_lower_threshold,
                arguments.weight_upper_threshold)
        elif arguments.genotype == uncorrelated_one_step_size_genotype_choice:
            initial_population = UncorrelatedOneStepSizeGenotype.get_random_genotypes(
                arguments.population_size, number_of_nn_weights,
                arguments.weight_lower_threshold,
                arguments.weight_upper_threshold,
                arguments.mutation_step_size_lower_threshold,
                arguments.mutation_step_size_upper_threshold)
        elif arguments.genotype == uncorrelated_n_step_size_genotype_choice:
            initial_population = UncorrelatedNStepSizeGenotype.get_random_genotypes(
                arguments.population_size, number_of_nn_weights,
                arguments.weight_lower_threshold,
                arguments.weight_upper_threshold,
                arguments.mutation_step_size_lower_threshold,
                arguments.mutation_step_size_upper_threshold)
        else:
            # Fail loudly instead of leaving initial_population unbound.
            raise ValueError(
                str.format('Unknown genotype choice: {}', arguments.genotype))

    population_backup = None
    if arguments.population_backup:
        population_backup = [
            arguments.population_backup_interval,
            arguments.population_backup_directory,
            arguments.population_backup_file_extension
        ]

    epoch_summary_strategy = None
    if arguments.epoch_summary:
        summary_features = [
            summary_lookup[summary_choice]
            for summary_choice in arguments.epoch_summary_features
        ]
        epoch_summary_strategy = [
            summary_features, arguments.epoch_summary_interval
        ]

    return get_evolved_population(
        initial_population, phenotype_strategy, evaluation_strategy,
        parent_selection_strategy, mutation_strategy,
        offspring_selection_strategy, arguments.duration,
        population_backup, epoch_summary_strategy)
Beispiel #13
0
def test_snake_moves():
    """Check a single ``move_forward`` step for every (heading, turn) pair.

    For each heading the snake lies in a straight line and is turned in all
    four directions. Turning to the opposite of the current heading is
    expected to be ignored, so the snake keeps going straight.
    """
    def game_representation_strategy(game):
        return Game.get_full_game_representation_strategy(game)

    snake_length = 5
    width = 32
    height = 18

    input_nodes = width * height + 4
    hidden_layer_nodes = [64]
    output_nodes = 3
    number_of_nn_weights = get_number_of_nn_weights(input_nodes,
                                                    hidden_layer_nodes,
                                                    output_nodes)
    weight_lower_threshold = -1
    weight_upper_threshold = 1

    sample_genotype = SimpleGenotype.get_random_genotype(
        number_of_nn_weights, weight_lower_threshold, weight_upper_threshold)
    sample_phenotype = Phenotype(sample_genotype.weights, input_nodes,
                                 hidden_layer_nodes, output_nodes)
    sample_game = Game(width, height, sample_phenotype, 777,
                       game_representation_strategy, snake_length)

    def assert_move(turn, expected_snake, snake=None, direction=None):
        """Move a copy of the sample game once and compare the snake.

        ``snake``/``direction`` of ``None`` keep the game's initial state.
        """
        game_copy = deepcopy(sample_game)
        if snake is not None:
            # Copy the fixture so an in-place move cannot leak into the
            # next case that uses the same fixture.
            game_copy.snake = list(snake)
        if direction is not None:
            game_copy.direction = direction
        game_copy.move_forward(turn)
        assert game_copy.snake == expected_snake
        assert len(game_copy.snake) == snake_length

    # Initial game: head at (16, 9), body to the right, heading LEFT.
    for turn, expected in [
        (Direction.LEFT, [(15, 9), (16, 9), (17, 9), (18, 9), (19, 9)]),
        (Direction.RIGHT, [(15, 9), (16, 9), (17, 9), (18, 9), (19, 9)]),
        (Direction.UP, [(16, 8), (16, 9), (17, 9), (18, 9), (19, 9)]),
        (Direction.DOWN, [(16, 10), (16, 9), (17, 9), (18, 9), (19, 9)]),
    ]:
        assert_move(turn, expected)

    # Heading RIGHT: head at (20, 9), body to the left.
    heading_right = [(20, 9), (19, 9), (18, 9), (17, 9), (16, 9)]
    for turn, expected in [
        (Direction.LEFT, [(21, 9), (20, 9), (19, 9), (18, 9), (17, 9)]),
        (Direction.RIGHT, [(21, 9), (20, 9), (19, 9), (18, 9), (17, 9)]),
        (Direction.UP, [(20, 8), (20, 9), (19, 9), (18, 9), (17, 9)]),
        (Direction.DOWN, [(20, 10), (20, 9), (19, 9), (18, 9), (17, 9)]),
    ]:
        assert_move(turn, expected, heading_right, Direction.RIGHT)

    # Heading UP: head at (16, 5), body below it. (The LEFT case previously
    # used the DOWN-shaped snake, contradicting its "rotated to UP" setup.)
    heading_up = [(16, 5), (16, 6), (16, 7), (16, 8), (16, 9)]
    for turn, expected in [
        (Direction.LEFT, [(15, 5), (16, 5), (16, 6), (16, 7), (16, 8)]),
        (Direction.RIGHT, [(17, 5), (16, 5), (16, 6), (16, 7), (16, 8)]),
        (Direction.UP, [(16, 4), (16, 5), (16, 6), (16, 7), (16, 8)]),
        (Direction.DOWN, [(16, 4), (16, 5), (16, 6), (16, 7), (16, 8)]),
    ]:
        assert_move(turn, expected, heading_up, Direction.UP)

    # Heading DOWN: head at (16, 9), body above it.
    heading_down = [(16, 9), (16, 8), (16, 7), (16, 6), (16, 5)]
    for turn, expected in [
        (Direction.LEFT, [(15, 9), (16, 9), (16, 8), (16, 7), (16, 6)]),
        (Direction.RIGHT, [(17, 9), (16, 9), (16, 8), (16, 7), (16, 6)]),
        (Direction.UP, [(16, 10), (16, 9), (16, 8), (16, 7), (16, 6)]),
        (Direction.DOWN, [(16, 10), (16, 9), (16, 8), (16, 7), (16, 6)]),
    ]:
        assert_move(turn, expected, heading_down, Direction.DOWN)