예제 #1
0
파일: trainer.py 프로젝트: pikma/rubik
    def generate_td_value_examples() -> Iterator[Tuple[np.ndarray, float]]:
        '''Yields an endless stream of (cube features, target value) pairs.

        Cubes come from `cube_lib.scramble_cube(TRAJECTORY_LENGTH)`, which is
        re-invoked forever. For each cube, every rotation is tried on a copy:

        * If one rotation produces a solved cube, the target is 1 and the pair
          is yielded right away.
        * Otherwise the features of all successor cubes are handed to the
          batcher, and the target of a cube is 1 plus the smallest prediction
          among its successors.

        Pairs of the second kind are yielded whenever the batcher hands
        predictions back, so they may belong to cubes enqueued during earlier
        iterations rather than to the current one.
        '''
        batcher = util.ModelBatcher(1024 * 4, model, feature_shape=(20, 24))

        while True:
            for cube in cube_lib.scramble_cube(TRAJECTORY_LENGTH):
                features = cube.as_numpy_array()
                successor_features = []

                for rotation in cube_lib.Rotation.all():
                    successor = cube.copy()
                    successor.rotate_face(rotation)
                    if successor.is_solved():
                        # One move away from solved: no model call needed.
                        break
                    successor_features.append(successor.as_numpy_array())
                else:
                    # No rotation solved the cube: bootstrap the target from
                    # the predictions for the successors. The cube's own
                    # features serve as the request id so that they come back
                    # alongside the predictions.
                    batcher.enqueue_predictions(
                        np.asarray(successor_features), request_id=features)

                    for predictions, request_features in (
                            batcher.get_predictions()):
                        yield (request_features, 1 + np.min(predictions))
                    continue

                yield (features, 1)
예제 #2
0
    def _get_next_rotation(self) -> Optional[cube_lib.Rotation]:
        '''Returns the next rotation of the solution, or None if there is none.

        The first call on an unsolved cube runs a search over cube states
        (A*-style: each state carries a cost-to-come and a model-estimated
        cost-to-go; the exact ordering is decided by `_PriorityQueue`). The
        resulting solution is cached in `self._solution`, and this call and
        the following ones pop rotations from that cache one at a time.

        Returns:
            The next rotation to apply, or None when the cube is already
            solved or the cached solution has been fully consumed.

        Raises:
            AssertionError: if the cached solution is exhausted but the cube
                is not solved, or if the search queue empties without ever
                reaching a solved state.
        '''
        if self._solution:
            return self._solution.popleft()
        if self._solution is not None:
            # We are done applying the rotations in the solution, so the cube
            # must now be solved. `is_solved` has to be *called*: asserting
            # the bound method itself is always truthy and checks nothing.
            assert self.cube.is_solved()
            return None

        if self.cube.is_solved():
            return None

        # States that have already been expanded, keyed by their cube.
        done_states: Dict[cube_lib.Cube, _AStarState] = {}
        queue = _PriorityQueue()

        # The root state is evaluated on its own, as a batch of one.
        est_distance = self._model.predict(self.cube.as_numpy_array().reshape(
            (1, 20, 24))).item()
        queue.add_or_update_state(
            _AStarState(cube=self.cube,
                        cost_to_come=0,
                        est_cost_to_go=est_distance,
                        previous_state=None,
                        previous_rotation=None))

        while queue:
            state = queue.pop_min_state()
            if state.cube.is_solved():
                self._solution = self._compute_solution(state, done_states)
                # The solution is cached, we can just recurse and it will read
                # the first move from the cached solution.
                return self._get_next_rotation()

            # One batcher per expansion, sized to hold exactly one prediction
            # per rotation.
            model_batcher = util.ModelBatcher(cube_lib.NUM_ROTATIONS,
                                              self._model,
                                              feature_shape=(20, 24))

            for rotation in cube_lib.Rotation.all():
                new_cube = state.cube.copy()
                new_cube.rotate_face(rotation)

                new_state = _AStarState(
                    cube=new_cube,
                    cost_to_come=state.cost_to_come + 1,
                    est_cost_to_go=-1,  # will be set with the model prediction.
                    previous_state=state.cube,
                    previous_rotation=rotation)

                # The state itself is the request id, so that the prediction
                # can be written back onto it below.
                model_batcher.enqueue_predictions(
                    new_cube.as_numpy_array().reshape((1, 20, 24)),
                    request_id=new_state)

            model_batcher.flush()
            for value, new_state in model_batcher.get_predictions():
                new_state.est_cost_to_go = value
                queue.add_or_update_state(new_state)

            done_states[state.cube] = state

        # This code should be unreachable. An explicit raise is used instead
        # of `assert False` so that the check survives `python -O`; otherwise
        # the method would fall through and return None, which callers
        # interpret as "already solved".
        raise AssertionError(
            'The search queue was exhausted without reaching a solved cube.')
예제 #3
0
    def test_request_bigger_than_batch_size(self):
        '''A request with more rows than the batch size is returned whole.

        Five rows are enqueued with a batch size of two. No prediction is
        expected back until a sixth row arrives; at that point both requests
        must be returned, in order, each with its own request id. The expected
        predictions equal the features that were enqueued.
        '''
        model_batcher = util.ModelBatcher(batch_size=2,
                                          model=FakeModel(),
                                          feature_shape=(1, ))
        big_request = np.arange(1, 6).reshape((5, 1))

        model_batcher.enqueue_predictions(big_request, '1')
        pending = list(model_batcher.get_predictions())
        self.assertFalse(pending)

        model_batcher.enqueue_predictions(np.asarray([[6]]), '6')
        completed = list(model_batcher.get_predictions())

        self.assertEqual(len(completed), 2)

        first_result = completed[0]
        self.assertTrue((first_result[0] == big_request).all())
        self.assertEqual(first_result[1], '1')

        self.assertEqual(completed[1], (np.asarray([[6]]), '6'))
예제 #4
0
    def test_requests_of_one_and_two(self):
        '''Requests of two rows and one row together fill a batch of three.

        Nothing is expected back after the two-row request alone. Once the
        one-row request is added, both results must be returned in order.
        A further lone request is only returned after an explicit flush().
        '''
        model_batcher = util.ModelBatcher(batch_size=3,
                                          model=FakeModel(),
                                          feature_shape=(1, ),
                                          feature_dtype='int')

        # Two rows out of three: no result expected yet.
        model_batcher.enqueue_predictions(np.asarray([[1], [2]]), '1')
        pending = list(model_batcher.get_predictions())
        self.assertFalse(pending)

        # The third row arrives: both requests are expected back.
        model_batcher.enqueue_predictions(np.asarray([[3]]), '3')
        completed = list(model_batcher.get_predictions())
        self.assertEqual(len(completed), 2)

        first_result = completed[0]
        self.assertTrue((first_result[0] == [[1], [2]]).all())
        self.assertEqual(first_result[1], '1')
        self.assertEqual(completed[1], (np.asarray([[3]]), '3'))

        # A single row stays pending until flush() is called.
        model_batcher.enqueue_predictions(np.asarray([[4]]), '4')
        pending = list(model_batcher.get_predictions())
        self.assertFalse(pending)

        model_batcher.flush()
        flushed = list(model_batcher.get_predictions())
        self.assertListEqual(flushed, [
            (np.asarray([[4]]), '4'),
        ])
예제 #5
0
    def test_requests_of_one(self):
        '''Three single-row requests fill a batch of three.

        No result is expected after the first or the second request. After
        the third, all three must be returned in enqueue order. A fourth
        request is only returned after an explicit flush().
        '''
        model_batcher = util.ModelBatcher(batch_size=3,
                                          model=FakeModel(),
                                          feature_shape=(1, ),
                                          feature_dtype='int')

        # First and second rows: no result expected yet.
        model_batcher.enqueue_predictions(np.asarray([[1]]), '1')
        pending = list(model_batcher.get_predictions())
        self.assertFalse(pending)

        model_batcher.enqueue_predictions(np.asarray([[2]]), '2')
        pending = list(model_batcher.get_predictions())
        self.assertFalse(pending)

        # Third row: all three requests are expected back.
        model_batcher.enqueue_predictions(np.asarray([[3]]), '3')
        completed = list(model_batcher.get_predictions())
        self.assertListEqual(completed, [
            (np.asarray([[1]]), '1'),
            (np.asarray([[2]]), '2'),
            (np.asarray([[3]]), '3'),
        ])

        # A single row stays pending until flush() is called.
        model_batcher.enqueue_predictions(np.asarray([[4]]), '4')
        pending = list(model_batcher.get_predictions())
        self.assertFalse(pending)

        model_batcher.flush()
        flushed = list(model_batcher.get_predictions())
        self.assertListEqual(flushed, [
            (np.asarray([[4]]), '4'),
        ])
예제 #6
0
    def apply_next_rotation(self):
        '''Picks one rotation, applies it to `self.cube`, and returns it.

        The states reachable from the current cube are explored breadth-first
        up to `self._depth` rotations:

        * If a solved state is reached, the first rotation of the trajectory
          leading to it is applied and returned immediately.
        * Otherwise every state at the depth limit is scored by the model, and
          the first rotation of the trajectory with the smallest predicted
          value is applied and returned.

        Returns:
            The rotation that was applied, or None if the cube was already
            solved (in which case nothing is applied).
        '''
        if self.cube.is_solved():
            return None

        batcher = util.ModelBatcher(1024,
                                    self._model,
                                    feature_shape=(20, 24))

        # FIFO frontier: new trajectories enter on the left and the oldest one
        # leaves from the right, so states are visited in increasing order of
        # depth. The first solved state encountered is therefore reached by a
        # shortest path.
        frontier = collections.deque()
        frontier.append(_Trajectory(final_state=self.cube))

        visited = {self.cube}

        # Lowest prediction seen so far and the first rotation leading to it.
        best = {'value': None, 'rotation': None}

        def drain_predictions():
            '''Folds every available prediction into `best`.'''
            for prediction, rotation in batcher.get_predictions():
                distance = prediction[0]
                # The model predicts the distance, so we minimize it.
                if best['value'] is None or distance < best['value']:
                    best['value'] = distance
                    best['rotation'] = rotation

        while frontier:
            current = frontier.pop()
            cube = current.final_state

            if cube.is_solved():
                # Breadth-first order guarantees this trajectory is a shortest
                # one.
                self.cube.rotate_face(current.first_rotation)
                return current.first_rotation

            if current.num_rotations >= self._depth:
                # Leaf of the search tree: let the model evaluate the state.
                batcher.enqueue_predictions(
                    cube.as_numpy_array().reshape((1, 20, 24)),
                    request_id=current.first_rotation)
                drain_predictions()
                continue

            # Inner node: expand it with every rotation.
            for rotation in cube_lib.Rotation.all():
                successor = cube.copy()
                successor.rotate_face(rotation)

                if successor in visited:
                    continue
                visited.add(successor)

                # The root trajectory has no first rotation (None); in that
                # case the rotation being applied now is the first one.
                opening_rotation = current.first_rotation or rotation
                frontier.appendleft(
                    _Trajectory(final_state=successor,
                                first_rotation=opening_rotation,
                                num_rotations=current.num_rotations + 1))

        # Collect the predictions still sitting in an incomplete batch.
        batcher.flush()
        drain_predictions()

        chosen = best['rotation']
        self.cube.rotate_face(chosen)
        return chosen