Exemplo n.º 1
0
 def _get_correctness(self):
     """Check ``cube.as_correct`` against a hand-written reference table.

     The solved state is rotated twice (``rotate(state, 0, True)`` and then
     ``rotate(state, 5, False)``), wrapped in a batch of one, and the
     output of ``cube.as_correct`` is compared element-wise with the
     expected 6x8 table of +1/-1 entries below.
     """
     scrambled = cube.rotate(cube.get_solved(), 0, True)
     scrambled = cube.rotate(scrambled, 5, False)

     expected = torch.tensor(
         [
             [1, 1, 1, 1, -1, -1, -1, 1],
             [-1, 1, 1, 1, 1, 1, -1, -1],
             [-1, -1, -1, -1, -1, 1, 1, 1],
             [-1, -1, -1, -1, -1, 1, 1, 1],
             [-1, 1, 1, 1, 1, 1, -1, -1],
             [1, 1, -1, -1, -1, 1, 1, 1],
         ],
         device=gpu,
     )

     # as_correct is fed a batch, hence the added leading dimension.
     batch = torch.from_numpy(scrambled).unsqueeze(0)
     actual = cube.as_correct(batch)
     assert torch.all(expected == actual)
Exemplo n.º 2
0
 def _can_win_all_easy_games(self, agent):
     """Verify that a reported solution really solves a 2-move scramble.

     The agent searches a scramble produced by
     ``cube.scramble(2, force_not_solved=True)`` with a one second time
     limit. Nothing is asserted when the agent reports no solution;
     otherwise its action queue is replayed and the final state must be
     solved.
     """
     state, _, _ = cube.scramble(2, force_not_solved=True)
     if not agent.search(state, time_limit=1):
         return

     # Replay the actions the agent queued up during the search.
     for action_index in agent.action_queue:
         move = cube.action_space[action_index]
         state = cube.rotate(state, *move)
     assert cube.is_solved(state)
Exemplo n.º 3
0
	def search(self, state: np.ndarray, time_limit: float=None, max_states: int=None) -> bool:
		"""Breadth-first search for a solved state.

		Expands states in FIFO order, trying every entry of ``cube.action_space``
		on each one, until a solved state is found, the frontier is exhausted,
		the time limit is reached, or ``len(self)`` reaches ``max_states``.

		On success the indices (into ``cube.action_space``) of the actions
		leading from ``state`` to the solved state are pushed onto
		``self.action_queue`` in execution order.

		:param state: State to start the search from.
		:param time_limit: Passed through ``self.reset``, which returns the
			effective limit compared against ``self.tt.tock()``.
		:param max_states: Passed through ``self.reset``, which returns the
			effective limit compared against ``len(self)``.
		:return: True if ``state`` is already solved or a solution was found,
			False otherwise.
		"""
		time_limit, max_states = self.reset(time_limit, max_states)
		self.tt.tick()

		if cube.is_solved(state): return True

		# Maps the byte representation of each seen state to a tuple
		# (key of the state it was reached from, index of the action taken).
		# tobytes() yields the same bytes as the deprecated tostring() alias.
		self.states = { state.tobytes(): (None, None) }
		queue = deque([state])
		# The `queue` check prevents popleft() from raising IndexError should
		# the frontier ever be exhausted before one of the limits is hit.
		while queue and self.tt.tock() < time_limit and len(self) < max_states:
			state = queue.popleft()
			tstate = state.tobytes()
			for i, action in enumerate(cube.action_space):
				new_state = cube.rotate(state, *action)
				new_tstate = new_state.tobytes()
				if new_tstate in self.states:
					continue
				elif cube.is_solved(new_state):
					# Walk the parent links back to the root, prepending actions
					# so the queue ends up in the order they must be executed.
					self.action_queue.appendleft(i)
					parent_key = tstate
					while self.states[parent_key][0] is not None:
						self.action_queue.appendleft(self.states[parent_key][1])
						parent_key = self.states[parent_key][0]
					return True
				else:
					self.states[new_tstate] = (tstate, i)
					queue.append(new_state)

		return False
Exemplo n.º 4
0
 def _multi_rotate_test(self):
     """Check that ``cube.multi_rotate`` agrees with per-state ``cube.rotate``.

     A batch of five solved states is rotated ten times with random faces
     and directions; after every round the batched result must equal the
     result of rotating each state individually.
     """
     n_states = 5
     states = np.array([cube.get_solved()] * n_states)
     for _ in range(10):
         faces = np.random.randint(0, 6, n_states)
         # The upper bound of randint is exclusive: (0, 1) would only ever
         # produce 0 and leave one rotation direction untested.
         dirs = np.random.randint(0, 2, n_states)
         states_classic = np.array([
             cube.rotate(state, face, d)
             for state, face, d in zip(states, faces, dirs)
         ])
         states = cube.multi_rotate(states, faces, dirs)
         assert (states_classic == states).all()
Exemplo n.º 5
0
 def rotate(self, n: int):
     """Benchmark ``n`` consecutive single rotations.

     Random faces and directions are drawn up front; each call to
     ``cube.rotate`` is then wrapped individually in a ``self.tt`` profile
     so only the rotation itself is timed. Results are reported through
     ``self._log_method_results``.

     :param n: Number of rotations to perform.
     """
     header = f"Benchmarking {TickTock.thousand_seps(n)} single rotations, {_repstr()}"
     self.log.section(header)

     # Faces are drawn before directions, keeping the random stream order.
     faces = np.random.randint(0, 6, n)
     dirs = np.random.randint(0, 2, n)

     state = cube.get_solved()
     pname = f"Single rotation, {_repstr()}"
     for face, direction in zip(faces, dirs):
         self.tt.profile(pname)
         state = cube.rotate(state, face, direction)
         self.tt.end_profile()

     self._log_method_results("Average rotation time", pname)
Exemplo n.º 6
0
    def test_scramble(self):
        """Check that scrambling unsolves the cube and can be undone.

        With a fixed seed, a 1-move and a 20-move scramble must both yield
        an unsolved state. The 20-move scramble is then reverted by
        replaying its faces in reverse order with each direction flipped,
        which must give back a solved state.
        """
        np.random.seed(42)
        # cube.scramble returns its own state, so no solved state needs to
        # be created beforehand.
        state, faces, dirs = cube.scramble(1)
        assert not cube.is_solved(state)

        state, faces, dirs = cube.scramble(20)
        assert not cube.is_solved(state)

        # Undo the scramble: last move first, each direction inverted.
        inverse_dirs = [int(not d) for d in dirs]
        for f, d in zip(reversed(faces), reversed(inverse_dirs)):
            state = cube.rotate(state, f, d)
        assert cube.is_solved(state)
Exemplo n.º 7
0
 def test_expansion(self):
     """Check the A* bookkeeping for the root state and its direct successors.

     After a one second search from a 3-move scramble, the initial state
     must be stored at index 1 with ``G`` equal to 0, and every state
     reachable by one action must have ``G`` equal to 1 and the initial
     state's index as its parent.
     """
     net = Model.create(ModelConfig()).eval()
     init_state, _, _ = cube.scramble(3)
     agent = AStar(net, lambda_=0.1, expansions=5)
     agent.search(init_state, time_limit=1)

     # tobytes() returns the same bytes as the deprecated tostring() alias,
     # so lookups still match whatever byte keys the agent stored.
     init_idx = agent.indices[init_state.tobytes()]
     assert init_idx == 1
     assert agent.G[init_idx] == 0
     for action in cube.action_space:
         substate = cube.rotate(init_state, *action)
         idx = agent.indices[substate.tobytes()]
         assert agent.G[idx] == 1
         assert agent.parents[idx] == init_idx
Exemplo n.º 8
0
    def _mcts_test(self, state: np.ndarray, search_graph: bool):
        """Run a short MCTS search and verify the agent's internal bookkeeping.

        Checks, in order: the state/index mapping, the stored states, the
        neighbor table (only when ``search_graph`` is False), the cached
        policy and value against a fresh forward pass of the network, the
        leaf flags (only when ``search_graph`` is False), and that ``W`` is
        non-zero for every used index.

        :param state: State the search is started from.
        :param search_graph: Forwarded to the ``MCTS`` constructor; also
            selects which of the checks are run.
        :return: Tuple of the agent and the value returned by its search.
        """
        agent = MCTS(Model.create(ModelConfig()),
                     c=1,
                     search_graph=search_graph)
        solved = agent.search(state, .2)

        # Indices: the root sits at index 1 and the used indices are contiguous.
        # tobytes() returns the same bytes as the deprecated tostring() alias.
        assert agent.indices[state.tobytes()] == 1
        for s, i in agent.indices.items():
            assert agent.states[i].tobytes() == s
        sorted_idcs = sorted(agent.indices.values())
        assert sorted_idcs[0] == 1
        assert np.all(np.diff(sorted_idcs) == 1)

        used_idcs = np.array(list(agent.indices.values()))
        # Set copy for constant-time membership tests inside the loops below.
        used_set = set(used_idcs.tolist())

        # States
        assert np.all(agent.states[1] == state)
        for i, s in enumerate(agent.states):
            if i not in used_set: continue
            assert s.tobytes() in agent.indices
            assert agent.indices[s.tobytes()] == i

        # Neighbors
        if not search_graph:
            for i, neighs in enumerate(agent.neighbors):
                if i not in used_set: continue
                parent_state = agent.states[i]
                for j, neighbor_index in enumerate(neighs):
                    # A neighbor index of 0 is skipped; it appears to mark
                    # "no neighbor stored" -- confirm against the MCTS class.
                    assert neighbor_index == 0 or neighbor_index in used_set
                    if neighbor_index == 0: continue
                    substate = cube.rotate(parent_state, *cube.action_space[j])
                    assert np.all(agent.states[neighbor_index] == substate)

        # Policy and value
        with torch.no_grad():
            p, v = agent.net(cube.as_oh(agent.states[used_idcs]))
        p, v = p.softmax(dim=1).cpu().numpy(), v.squeeze().cpu().numpy()
        assert np.all(np.isclose(agent.P[used_idcs], p, atol=1e-5))
        assert np.all(np.isclose(agent.V[used_idcs], v, atol=1e-5))

        # Leaves: a row with every neighbor set must not be flagged as a leaf,
        # and vice versa.
        if not search_graph:
            assert np.all(agent.neighbors.all(axis=1) != agent.leaves)

        # W
        assert agent.W[used_idcs].all()

        return agent, solved
Exemplo n.º 9
0
    def _rotation_tests(self):
        """Exercise ``cube.rotate``, ``cube.is_solved`` and ``cube.stringify``.

        Covers: every action applied once without error, the string form of
        the solved cube, solved/unsolved status along two consecutive move
        sequences, the string form after a single move, and the string form
        after applying all twelve moves in a row.
        """
        # Smoke test: every action can be applied; the result is discarded.
        state = cube.get_solved()
        for action in cube.action_space:
            state = cube.rotate(state, *action)

        # String form of the solved cube (the original note says this also
        # covers as633 by extension).
        solved_layout = "\n".join([
            "      2 2 2            ",
            "      2 2 2            ",
            "      2 2 2            ",
            "4 4 4 0 0 0 5 5 5 1 1 1",
            "4 4 4 0 0 0 5 5 5 1 1 1",
            "4 4 4 0 0 0 5 5 5 1 1 1",
            "      3 3 3            ",
            "      3 3 3            ",
            "      3 3 3            ",
        ])
        state = cube.get_solved()
        assert cube.stringify(state) == solved_layout

        # Two move sequences applied back to back: the second continues from
        # the state the first one left behind and ends on a solved cube.
        first_moves = ((0, 1), (0, 0), (0, 1), (1, 1), (2, 0), (3, 0))
        first_solved = (False, True, False, False, False, False)
        second_moves = ((3, 1), (2, 1), (1, 0), (0, 0))
        second_solved = (False, False, False, True)
        for moves, assembled in ((first_moves, first_solved),
                                 (second_moves, second_solved)):
            for move, should_be_solved in zip(moves, assembled):
                state = cube.rotate(state, *move)
                assert should_be_solved == cube.is_solved(state)

        # A single move compared against its expected string representation.
        single_move_layout = "\n".join([
            "      2 2 2            ",
            "      2 2 2            ",
            "      5 5 5            ",
            "4 4 2 0 0 0 3 5 5 1 1 1",
            "4 4 2 0 0 0 3 5 5 1 1 1",
            "4 4 2 0 0 0 3 5 5 1 1 1",
            "      4 4 4            ",
            "      3 3 3            ",
            "      3 3 3            ",
        ])
        state = cube.rotate(cube.get_solved(), 0, 1)
        assert cube.stringify(state) == single_move_layout

        # All twelve moves: faces 0..5 with direction 0, then faces 0..5 with
        # direction 1. The cube must be unsolved after every one of them.
        state = cube.get_solved()
        all_moves = tuple((face, d) for d in (0, 1) for face in range(6))
        never_solved = (False,) * len(all_moves)
        for move, should_be_solved in zip(all_moves, never_solved):
            state = cube.rotate(state, *move)
            assert should_be_solved == cube.is_solved(state)

        all_moves_layout = "\n".join([
            "      2 0 2            ",
            "      5 2 4            ",
            "      2 1 2            ",
            "4 2 4 0 2 0 5 2 5 1 2 1",
            "4 4 4 0 0 0 5 5 5 1 1 1",
            "4 3 4 0 3 0 5 3 5 1 3 1",
            "      3 1 3            ",
            "      5 3 4            ",
            "      3 0 3            ",
        ])
        assert cube.stringify(state) == all_moves_layout
Exemplo n.º 10
0
 def _test_agents(self, agent: Agent):
     """Check that an agent's verdict matches the outcome of its actions.

     The agent searches a 4-move scramble with a limit of 0.05. Its action
     queue is then replayed, and the resulting state must be solved exactly
     when the agent reported that it found a solution.
     """
     scrambled, _, _ = cube.scramble(4)
     solution_found = agent.search(scrambled, .05)

     final_state = scrambled
     for action_index in agent.action_queue:
         move = cube.action_space[action_index]
         final_state = cube.rotate(final_state, *move)

     assert solution_found == cube.is_solved(final_state)
Exemplo n.º 11
0
def _action_queue_test(state, agent, sol_found):
    """Validate an agent's action queue against the reported search result.

    Every queued action must be a valid index below ``cube.action_dim``,
    and replaying the queue from ``state`` must end in a solved cube
    exactly when ``sol_found`` is true.
    """
    assert all(0 <= queued < cube.action_dim for queued in agent.action_queue)

    current = state
    for queued in agent.action_queue:
        current = cube.rotate(current, *cube.action_space[queued])

    assert cube.is_solved(current) == sol_found
Exemplo n.º 12
0
	def _step(self, state: np.ndarray) -> (int, np.ndarray, bool):
		"""Apply one uniformly random action to ``state``.

		:param state: State to rotate.
		:return: Tuple of the chosen action index, the resulting state, and
			whether that resulting state is solved.
		"""
		chosen = np.random.randint(cube.action_dim)
		move = cube.action_space[chosen]
		next_state = cube.rotate(state, *move)
		solved = cube.is_solved(next_state)
		return chosen, next_state, solved
Exemplo n.º 13
0
	def _step(self, state: np.ndarray) -> (int, np.ndarray, bool):
		"""Apply one action chosen from the network's policy output.

		The network output for ``state`` is turned into probabilities with a
		softmax. The action is sampled from them when ``self.sample_policy``
		is set; otherwise the most probable action is taken.

		:param state: State to rotate.
		:return: Tuple of the chosen action index, the resulting state, and
			whether that resulting state is solved.
		"""
		net_output = self.net(cube.as_oh(state), value=False).cpu()
		policy = torch.nn.functional.softmax(net_output, dim=1).numpy().squeeze()

		if self.sample_policy:
			action = np.random.choice(cube.action_dim, p=policy)
		else:
			action = policy.argmax()

		next_state = cube.rotate(state, *cube.action_space[action])
		return action, next_state, cube.is_solved(next_state)