Example #1
0
    def __init__(self, spec, compute_transitions=True):
        """Build the train gridworld described by `spec`.

        Args:
            spec: Environment specification. The fields read here are
                height, width, init_state, carpet_locations,
                feature_locations and train_transition.
            compute_transitions: Boolean. If True, enumerate the states and
                build the transition and feature matrices immediately.

        Raises:
            ValueError: if the keys and the values of spec.train_transition
                are not the same set of locations.

        Attributes set here:
            height: Integer, height of the grid. Y coordinates are in
                [0, height).
            width: Integer, width of the grid. X coordinates are in
                [0, width).
            init_state: TrainState, deep copy of spec.init_state.
            vase_locations: List of (x, y) tuples, the keys of
                init_state.vase_states.
            num_vases: Integer, number of vases.
            carpet_locations: Set of (x, y) tuples, locations of carpets.
            feature_locations: List of (x, y) tuples, locations of features.
            train_transition: Mapping taken from spec.train_transition
                (presumably each train location -> the next one on its
                track; confirm).
            train_locations: List of the keys of train_transition.
            default_action: Integer, the action number of Direction.STAY.
            nA: Integer, number of actions.
            num_features: Integer, length of the feature vector that
                s_to_f produces for init_state.
            s: TrainState, current state (presumably set by reset();
                confirm).
        """
        self.height = spec.height
        self.width = spec.width
        self.init_state = deepcopy(spec.init_state)
        self.vase_locations = list(self.init_state.vase_states.keys())
        self.num_vases = len(self.vase_locations)
        self.carpet_locations = set(spec.carpet_locations)
        self.feature_locations = list(spec.feature_locations)
        self.train_transition = spec.train_transition
        self.train_locations = list(self.train_transition.keys())
        # Every location the train can move to must also be a location it can
        # move from (and vice versa). This is an explicit check rather than an
        # `assert` so that it still runs under `python -O`.
        if set(self.train_locations) != set(self.train_transition.values()):
            raise ValueError(
                'train_transition must map a set of locations onto itself; '
                'got keys {} and values {}'.format(
                    sorted(set(self.train_locations)),
                    sorted(set(self.train_transition.values()))))

        self.default_action = Direction.get_number_from_direction(
            Direction.STAY)
        self.nA = 5
        self.num_features = len(self.s_to_f(self.init_state))

        self.reset()

        if compute_transitions:
            # NOTE(review): self.nS is not assigned in this method; presumably
            # reset() or enumerate_states() sets it. Confirm.
            states = self.enumerate_states()
            self.make_transition_matrices(states, range(self.nA), self.nS,
                                          self.nA)
            self.make_f_matrix(self.nS, self.num_features)
Example #2
0
    def get_next_states(self, state, action):
        """Return the possible successors of `state` after taking `action`.

        Action handling:
          * the STAY direction number leaves position and objects unchanged;
          * any other number below len(Direction.ALL_DIRECTIONS) turns the
            agent to face that direction. The agent also steps one cell that
            way when the target is inside the grid and listed in
            self.possible_agent_locations; otherwise only the facing changes;
          * 5 interacts with the faced cell. A carried apple is always
            dropped, and it goes into the faced bucket (capped at
            self.bucket_capacity) if there is one. If the agent carries
            nothing and faces a tree that has an apple, it picks that apple.
        Any other action raises ValueError.

        Returns:
            A list of (probability, ApplesState, 0) triples, one per outcome
            produced by self.regen_apples. The trailing 0 is presumably the
            reward; confirm.
        """
        action = int(action)
        orientation, x, y = state.agent_pos
        new_orientation, new_x, new_y = orientation, x, y
        trees_after = deepcopy(state.tree_states)
        buckets_after = deepcopy(state.bucket_states)
        holding = state.carrying_apple

        stay = Direction.get_number_from_direction(Direction.STAY)
        if action == stay:
            pass  # Nothing changes.
        elif action < len(Direction.ALL_DIRECTIONS):
            # Turning always happens; stepping only if the target is free.
            new_orientation = action
            tx, ty = Direction.move_in_direction_number((x, y), action)
            target_ok = (0 <= tx < self.width
                         and 0 <= ty < self.height
                         and (tx, ty) in self.possible_agent_locations)
            if target_ok:
                new_x, new_y = tx, ty
        elif action == 5:
            facing = Direction.move_in_direction_number((x, y), orientation)
            if state.carrying_apple:
                # The apple is dropped no matter what is in front.
                holding = False
                if facing in buckets_after:
                    buckets_after[facing] = min(buckets_after[facing] + 1,
                                                self.bucket_capacity)
            elif facing in trees_after and trees_after[facing]:
                holding = True
                trees_after[facing] = False
        else:
            raise ValueError('Invalid action {}'.format(action))

        new_pos = (new_orientation, new_x, new_y)

        # Regrowth receives the pre-action apple flags as well, so that an
        # apple picked in this very step is not regenerated immediately.
        apples_before = [state.tree_states[loc] for loc in self.tree_locations]
        apples_after = [trees_after[loc] for loc in self.tree_locations]

        successors = []
        for prob, tree_apples in self.regen_apples(apples_before,
                                                   apples_after):
            trees = dict(zip(self.tree_locations, tree_apples))
            successor = ApplesState(new_pos, trees, buckets_after, holding)
            successors.append((prob, successor, 0))
        return successors
Example #3
0
    def test_direction_number_conversion(self):
        """Directions and their numbers must form a bijection onto 0..n-1.

        Checks three things:
          * every direction survives direction -> number -> direction;
          * all directions are distinct;
          * the numbers are exactly 0, 1, ..., len(ALL_DIRECTIONS) - 1.
        """
        directions = Direction.ALL_DIRECTIONS

        seen_numbers = []
        for dirn in directions:
            num = Direction.get_number_from_direction(dirn)
            self.assertEqual(dirn, Direction.get_direction_from_number(num))
            seen_numbers.append(num)

        count = len(directions)
        self.assertEqual(len(set(directions)), count)
        self.assertEqual(set(seen_numbers), set(range(count)))
Example #4
0
    def __init__(self, spec, compute_transitions=True):
        """Build the apple-collection gridworld described by `spec`.

        Args:
            spec: Environment specification. The fields read here are
                height, width, apple_regen_probability, bucket_capacity,
                init_state and include_location_features.
            compute_transitions: Boolean. If True, enumerate the states and
                build the transition and feature matrices immediately.

        Attributes set here:
            height: Integer, height of the grid. Y coordinates are in
                [0, height).
            width: Integer, width of the grid. X coordinates are in
                [0, width).
            apple_regen_probability: Copied from spec (presumably used by
                regen_apples; confirm).
            bucket_capacity: Maximum number of apples get_next_states lets a
                bucket hold.
            init_state: ApplesState, deep copy of spec.init_state.
            include_location_features: Copied from spec (presumably read by
                s_to_f; confirm).
            tree_locations: List of (x, y) tuples, the keys of
                init_state.tree_states.
            bucket_locations: List of (x, y) tuples, the keys of
                init_state.bucket_states.
            possible_agent_locations: List of (x, y) cells holding neither a
                tree nor a bucket, in itertools.product(range(width),
                range(height)) order.
            num_trees: Integer, number of trees.
            num_buckets: Integer, number of buckets.
            default_action: Integer, the action number of Direction.STAY.
            nA: Integer, number of actions (get_next_states treats 5 as
                "interact").
            num_features: Integer, length of the feature vector that
                s_to_f produces for init_state.
            s: ApplesState, current state (presumably set by reset();
                confirm).
        """
        self.height = spec.height
        self.width = spec.width
        self.apple_regen_probability = spec.apple_regen_probability
        self.bucket_capacity = spec.bucket_capacity
        self.init_state = deepcopy(spec.init_state)
        self.include_location_features = spec.include_location_features

        self.tree_locations = list(self.init_state.tree_states.keys())
        self.bucket_locations = list(self.init_state.bucket_states.keys())
        # Trees and buckets block movement; the agent may stand anywhere else.
        occupied = set(self.tree_locations + self.bucket_locations)
        self.possible_agent_locations = [
            pos for pos in product(range(self.width), range(self.height))
            if pos not in occupied
        ]

        self.num_trees = len(self.tree_locations)
        self.num_buckets = len(self.bucket_locations)

        self.default_action = Direction.get_number_from_direction(
            Direction.STAY)
        self.nA = 6
        self.num_features = len(self.s_to_f(self.init_state))

        self.reset()

        if compute_transitions:
            # NOTE(review): self.nS is not assigned in this method; presumably
            # reset() or enumerate_states() sets it. Confirm.
            states = self.enumerate_states()
            self.make_transition_matrices(states, range(self.nA), self.nS,
                                          self.nA)
            self.make_f_matrix(self.nS, self.num_features)
Example #5
0
    # |  T|
    # -----
    # Where the agent has picked apples from the right-hand trees once and put
    # them in the bucket.
    'default':
    (ApplesSpec(5,
                3,
                ApplesState(agent_pos=(0, 0, 2),
                            tree_states={
                                (0, 0): True,
                                (2, 0): True,
                                (2, 4): True
                            },
                            bucket_states={(1, 2): 0},
                            carrying_apple=False),
                apple_regen_probability=0.1,
                bucket_capacity=10,
                include_location_features=True),
     ApplesState(agent_pos=(Direction.get_number_from_direction(
         Direction.SOUTH), 1, 1),
                 tree_states={
                     (0, 0): True,
                     (2, 0): False,
                     (2, 4): True
                 },
                 bucket_states={(1, 2): 2},
                 carrying_apple=False),
     np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0]), np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
}