Ejemplo n.º 1
0
    def __init__(self, spec):
        """Build the environment from ``spec`` and precompute its matrices.

        Args:
            spec: Object providing ``init_state`` (which carries a
                ``battery_present`` mapping), ``height``, ``width``,
                ``feature_locations`` and ``train_transition`` (a mapping
                whose keys and values must cover the same set of locations).
        """
        # Grid geometry.
        self.width = spec.width
        self.height = spec.height

        # Keep a private copy so the spec's initial state is never mutated.
        self.init_state = deepcopy(spec.init_state)
        battery_keys = self.init_state.battery_present.keys()
        self.battery_locations = sorted(battery_keys)
        self.num_batteries = len(self.battery_locations)
        self.feature_locations = list(spec.feature_locations)

        # Every location the train leaves must also be a location it enters.
        self.train_transition = spec.train_transition
        self.train_locations = list(self.train_transition.keys())
        destinations = set(self.train_transition.values())
        assert set(self.train_locations) == destinations

        # Number of actions.
        self.nA = 5

        # The attributes above are set before the base-class initializer runs.
        super().__init__(10)

        stay = Direction.STAY
        self.default_action = Direction.get_number_from_direction(stay)
        initial_features = self.s_to_f(self.init_state)
        self.num_features = len(initial_features)

        self.reset()

        # Precompute transition and feature matrices over all states.
        all_states = self.enumerate_states()
        action_ids = range(self.nA)
        self.make_transition_matrices(all_states, action_ids, self.nS, self.nA)
        self.make_f_matrix(self.nS, self.num_features)
Ejemplo n.º 2
0
    def __init__(self, spec):
        """Build the apples environment from ``spec`` and precompute matrices.

        Args:
            spec: Object providing ``height``, ``width``, ``init_state`` (with
                ``tree_states`` and ``bucket_states`` mappings),
                ``apple_regen_probability``, ``bucket_capacity`` and
                ``include_location_features``.
        """
        # Plain settings copied from the spec.
        self.width = spec.width
        self.height = spec.height
        self.apple_regen_probability = spec.apple_regen_probability
        self.bucket_capacity = spec.bucket_capacity
        self.include_location_features = spec.include_location_features

        # Keep a private copy so the spec's initial state is never mutated.
        self.init_state = deepcopy(spec.init_state)

        self.tree_locations = list(self.init_state.tree_states.keys())
        self.bucket_locations = list(self.init_state.bucket_states.keys())
        self.num_trees = len(self.tree_locations)
        self.num_buckets = len(self.bucket_locations)

        # The agent may stand on any grid cell not taken by a tree or bucket.
        occupied = set(self.tree_locations) | set(self.bucket_locations)
        self.possible_agent_locations = [
            cell for cell in product(range(self.width), range(self.height))
            if cell not in occupied
        ]

        # Number of actions.
        self.nA = 6

        # The attributes above are set before the base-class initializer runs.
        super().__init__(max(5, self.bucket_capacity))

        stay = Direction.STAY
        self.default_action = Direction.get_number_from_direction(stay)
        initial_features = self.s_to_f(self.init_state)
        self.num_features = len(initial_features)

        self.reset()

        # Precompute transition and feature matrices over all states.
        all_states = self.enumerate_states()
        action_ids = range(self.nA)
        self.make_transition_matrices(all_states, action_ids, self.nS, self.nA)
        self.make_f_matrix(self.nS, self.num_features)
Ejemplo n.º 3
0
    def get_next_states(self, state, action):
        """Enumerate the possible successors of ``state`` under ``action``.

        The agent's part of the transition (turning, moving or interacting) is
        worked out first; ``self.regen_apples`` then supplies the possible
        tree outcomes together with their probabilities.

        Args:
            state: Current state; its ``agent_pos`` is an
                ``(orientation, x, y)`` triple.
            action: Action number (converted with ``int``). The STAY number
                does nothing, any other number below
                ``len(Direction.ALL_DIRECTIONS)`` turns and possibly moves
                the agent, and 5 interacts with the faced cell.

        Returns:
            A list of ``(probability, ApplesState, 0)`` tuples, one per
            outcome produced by ``self.regen_apples``.

        Raises:
            ValueError: If ``action`` matches none of the cases above.
        """
        action = int(action)
        facing, col, row = state.agent_pos
        next_facing, next_col, next_row = facing, col, row
        trees_after = deepcopy(state.tree_states)
        buckets_after = deepcopy(state.bucket_states)
        holding_after = state.carrying_apple

        stay_number = Direction.get_number_from_direction(Direction.STAY)
        if action != stay_number:
            if action < len(Direction.ALL_DIRECTIONS):
                # The agent always turns to face the chosen direction; it
                # only changes cell when the target cell is legal.
                next_facing = action
                step_col, step_row = Direction.move_in_direction_number(
                    (col, row), action)
                inside_grid = (0 <= step_col < self.width
                               and 0 <= step_row < self.height)
                if inside_grid and ((step_col, step_row)
                                    in self.possible_agent_locations):
                    next_col, next_row = step_col, step_row
            elif action == 5:
                # Interact with the cell directly in front of the agent.
                faced_cell = Direction.move_in_direction_number((col, row),
                                                                facing)
                if state.carrying_apple:
                    # The apple is dropped no matter what is in front; it is
                    # only kept if the faced cell is a bucket, up to capacity.
                    holding_after = False
                    if faced_cell in buckets_after:
                        filled = buckets_after[faced_cell] + 1
                        buckets_after[faced_cell] = min(filled,
                                                        self.bucket_capacity)
                elif faced_cell in trees_after and trees_after[faced_cell]:
                    # Pick the apple off the faced tree.
                    holding_after = True
                    trees_after[faced_cell] = False
                # Otherwise: empty-handed and no apple in front, so no-op.
            else:
                raise ValueError("Invalid action {}".format(action))

        next_pos = next_facing, next_col, next_row

        # Regeneration is given the apple flags from the original state as
        # well, so that an apple picked in this step is not regrown at once.
        apples_before = [
            state.tree_states[loc] for loc in self.tree_locations
        ]
        apples_after = [trees_after[loc] for loc in self.tree_locations]

        return [
            (prob,
             ApplesState(next_pos, dict(zip(self.tree_locations, tree_apples)),
                         buckets_after, holding_after),
             0)
            for prob, tree_apples in self.regen_apples(apples_before,
                                                       apples_after)
        ]
Ejemplo n.º 4
0
    def test_direction_number_conversion(self):
        """Directions round-trip through numbers, which must be 0..n-1."""
        directions = Direction.ALL_DIRECTIONS
        seen_numbers = set()

        for original in directions:
            as_number = Direction.get_number_from_direction(original)
            round_tripped = Direction.get_direction_from_number(as_number)
            self.assertEqual(original, round_tripped)
            seen_numbers.add(as_number)

        expected_count = len(directions)
        # No direction may be listed twice.
        self.assertEqual(len(set(directions)), expected_count)
        # The numbers used are exactly 0, 1, ..., expected_count - 1.
        self.assertEqual(seen_numbers, set(range(expected_count)))
Ejemplo n.º 5
0
    def get_next_state(self, state, action):
        """Return the state reached from ``state`` by taking ``action``.

        Args:
            state: Current position, passed as-is to
                ``Direction.move_in_direction_number``.
            action: Action number (converted with ``int``).

        Returns:
            The new ``(x, y)`` position when the move stays inside the grid;
            otherwise the unchanged ``state`` (also for the STAY action).

        Raises:
            ValueError: If ``action`` is neither STAY nor a direction number
                below ``len(Direction.ALL_DIRECTIONS)``.
        """
        action = int(action)

        # Staying put never changes the state.
        if action == Direction.get_number_from_direction(Direction.STAY):
            return state

        if action >= len(Direction.ALL_DIRECTIONS):
            raise ValueError("Invalid action {}".format(action))

        target_x, target_y = Direction.move_in_direction_number(state, action)
        in_bounds = (0 <= target_x < self.width
                     and 0 <= target_y < self.height)
        # A move that would leave the grid keeps the agent where it is.
        return (target_x, target_y) if in_bounds else state
Ejemplo n.º 6
0
    def __init__(self, prob, use_pixels_as_observations=True):
        """Set up a fixed 3x3 environment and precompute its matrices.

        Args:
            prob: Stored unchanged as ``self.prob``; how it is used is decided
                outside this method.
            use_pixels_as_observations: Forwarded to the base-class
                initializer.
        """
        # 3x3 grid with the initial state at (1, 1).
        self.height = self.width = 3
        self.init_state = (1, 1)
        self.prob = prob

        # One state per grid cell, and five actions.
        self.nS = self.height * self.width
        self.nA = 5

        # The attributes above are set before the base-class initializer runs.
        super().__init__(
            1, use_pixels_as_observations=use_pixels_as_observations)

        # Preliminary value, replaced just below by the measured length.
        # NOTE(review): possibly read by s_to_f -- confirm before removing.
        self.num_features = 2
        stay = Direction.STAY
        self.default_action = Direction.get_number_from_direction(stay)
        initial_features = self.s_to_f(self.init_state)
        self.num_features = len(initial_features)

        self.reset()

        # Precompute transition and feature matrices over all states.
        all_states = self.enumerate_states()
        action_ids = range(self.nA)
        self.make_transition_matrices(all_states, action_ids, self.nS, self.nA)
        self.make_f_matrix(self.nS, self.num_features)
Ejemplo n.º 7
0
    def __init__(self, spec):
        """Build the environment from ``spec`` and precompute its matrices.

        Args:
            spec: Object providing ``height``, ``width``, ``init_state`` (with
                a ``vase_states`` mapping), ``carpet_locations``,
                ``feature_locations`` and ``use_pixels_as_observations``.
        """
        # Grid geometry.
        self.width = spec.width
        self.height = spec.height

        # Keep a private copy so the spec's initial state is never mutated.
        self.init_state = deepcopy(spec.init_state)
        vase_keys = self.init_state.vase_states.keys()
        self.vase_locations = list(vase_keys)
        self.num_vases = len(self.vase_locations)

        # Carpets are kept as a set, feature locations as a list.
        self.carpet_locations = set(spec.carpet_locations)
        self.feature_locations = list(spec.feature_locations)

        # Number of actions.
        self.nA = 5

        # The attributes above are set before the base-class initializer runs.
        pixels = spec.use_pixels_as_observations
        super().__init__(1, use_pixels_as_observations=pixels)

        stay = Direction.STAY
        self.default_action = Direction.get_number_from_direction(stay)
        initial_features = self.s_to_f(self.init_state)
        self.num_features = len(initial_features)

        # States are enumerated before the reset here; that order is kept.
        all_states = self.enumerate_states()
        self.reset()

        # Precompute transition and feature matrices over all states.
        action_ids = range(self.nA)
        self.make_transition_matrices(all_states, action_ids, self.nS, self.nA)
        self.make_f_matrix(self.nS, self.num_features)
Ejemplo n.º 8
0
 def _collect_data(self, n_rollouts, debug_only_stay=False):
     """Roll out ``self.env`` and record observations and actions.

     Each rollout is repeated until it yields at least ``self.timesteps``
     observations (the initial observation from ``reset`` is counted).

     Args:
         n_rollouts: Number of trajectories to collect.
         debug_only_stay: If true, always take the STAY action instead of
             sampling from ``self.env.action_space``.

     Returns:
         A pair ``(observations, actions)`` of lists with one entry per
         rollout. Each actions entry ends with an extra
         ``np.zeros(self.action_space_shape)`` so that it has the same
         length as the matching observations entry.
     """
     observations, actions = [], []
     for _ in range(n_rollouts):
         # Roll out first and check the length afterwards, so the trajectory
         # buffers always exist -- even when self.timesteps <= 0, where a
         # check-first loop would never run and leave them unbound.
         while True:
             obs = self.env.reset()
             traj_obs = [obs]
             traj_act = []
             done = False
             while not done:
                 if debug_only_stay:
                     action = Direction.get_number_from_direction(
                         Direction.STAY)
                 else:
                     action = self.env.action_space.sample()
                 obs, _, done, _ = self.env.step(action)
                 traj_obs.append(obs)
                 traj_act.append(action)
             # Keep only trajectories with enough observations.
             if len(traj_obs) >= self.timesteps:
                 break
         # Pad the actions so they line up one-to-one with the observations.
         traj_act.append(np.zeros(self.action_space_shape))
         observations.append(traj_obs)
         actions.append(traj_act)
     return observations, actions
Ejemplo n.º 9
0
            3,
            ApplesState(
                agent_pos=(0, 0, 2),
                tree_states={
                    (0, 0): True,
                    (2, 0): True,
                    (2, 4): True
                },
                bucket_states={(1, 2): 0},
                carrying_apple=False,
            ),
            apple_regen_probability=0.1,
            bucket_capacity=10,
            include_location_features=True,
        ),
        ApplesState(
            agent_pos=(Direction.get_number_from_direction(Direction.SOUTH), 1,
                       1),
            tree_states={
                (0, 0): True,
                (2, 0): False,
                (2, 4): True
            },
            bucket_states={(1, 2): 2},
            carrying_apple=False,
        ),
        np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
    )
}