Ejemplo n.º 1
0
def test_simple():
    for game in games:
        sm = lookup.by_name(game).get_sm()
        bs_0 = sm.get_initial_state()

        bs_1 = sm.new_base_state()
        bs_1.assign(bs_0)
        for i in range(3):
            advance_state(sm, bs_1)

        assert bs_0 != bs_1

        l0 = decode_state(encode_state(bs_0.to_list()))
        l1 = decode_state(encode_state(bs_1.to_list()))

        decode_bs_0 = sm.new_base_state()
        decode_bs_1 = sm.new_base_state()
        decode_bs_0.from_list(l0)
        decode_bs_1.from_list(l1)

        assert bs_0.to_string() == bs_0.to_string()

        assert decode_bs_0 == bs_0
        assert decode_bs_0.hash_code() == bs_0.hash_code()

        print len(decode_bs_0.to_string())
        print len(bs_0.to_string())

        #assert decode_bs_0.to_string() == bs_0.to_string()

        assert decode_bs_1 == bs_1
        assert decode_bs_1.hash_code() == bs_1.hash_code()
        assert decode_bs_1.to_string() == bs_1.to_string()
Ejemplo n.º 2
0
def test_speed():
    import time

    for game in games:
        print "doing", game
        sm = lookup.by_name(game).get_sm()

        # a couple of states
        bs_0 = sm.get_initial_state()

        bs_1 = sm.new_base_state()
        bs_1.assign(bs_0)
        for i in range(5):
            advance_state(sm, bs_1)

        # encode states
        encoded_0 = encode_state(bs_0.to_list())
        encoded_1 = encode_state(bs_1.to_list())

        assert decode_state(encoded_0) == fast_decode_state(encoded_0)
        assert decode_state(encoded_1) == fast_decode_state(encoded_1)

        s = time.time()
        for i in range(10000):
            l0 = decode_state(encoded_0)
            l1 = decode_state(encoded_1)

        print "time taken %.3f msecs" % ((time.time() - s) * 1000)

        s = time.time()
        for i in range(10000):
            l0 = fast_decode_state(encoded_0)
            l1 = fast_decode_state(encoded_1)

        print "time taken %.3f msecs" % ((time.time() - s) * 1000)
Ejemplo n.º 3
0
def test_more():
    for game in games:
        print "doing", game
        sm = lookup.by_name(game).get_sm()
        bs_0 = sm.get_initial_state()

        bs_1 = sm.new_base_state()
        bs_1.assign(bs_0)
        for i in range(5):
            advance_state(sm, bs_1)

        assert bs_0 != bs_1

        # states to compare
        decode_bs_0 = sm.new_base_state()
        decode_bs_1 = sm.new_base_state()
        decode_direct_bs_0 = sm.new_base_state()
        decode_direct_bs_1 = sm.new_base_state()

        # encode as before
        en_0 = encode_state(bs_0.to_list())
        en_1 = encode_state(bs_1.to_list())

        # decode as before
        l0 = decode_state(en_0)
        l1 = decode_state(en_1)
        decode_bs_0.from_list(l0)
        decode_bs_1.from_list(l1)

        # decode directly
        decode_direct_bs_0.from_string(base64.decodestring(en_0))
        decode_direct_bs_1.from_string(base64.decodestring(en_1))

        # all checks
        assert decode_bs_0 == bs_0
        assert decode_bs_0.hash_code() == bs_0.hash_code()
        assert decode_bs_0.to_string() == bs_0.to_string()

        assert decode_direct_bs_0 == bs_0
        assert decode_direct_bs_0.hash_code() == bs_0.hash_code()
        assert decode_direct_bs_0.to_string() == bs_0.to_string()

        assert decode_bs_1 == bs_1
        assert decode_bs_1.hash_code() == bs_1.hash_code()
        assert decode_bs_1.to_string() == bs_1.to_string()

        assert decode_direct_bs_1 == bs_1
        assert decode_direct_bs_1.hash_code() == bs_1.hash_code()
        assert decode_direct_bs_1.to_string() == bs_1.to_string()

        print "good", game
Ejemplo n.º 4
0
    def on_request_samples(self, server, msg):
        """Handle a sample request: run the supervisor and send back samples.

        Records the request time, clears ``self.samples`` and the supervisor
        stats, registers every state in ``msg.new_states`` with the supervisor
        as a unique state, runs the supervisor's poll loop, logs a summary and
        sends a ``RequestSampleResponse`` carrying ``self.samples`` to
        ``server``.
        """
        self.on_request_samples_time = time.time()

        assert self.supervisor is not None
        self.samples = []
        self.supervisor.reset_stats()

        log.debug("Got request for sample with number unique states %s" %
                  len(msg.new_states))

        # let the supervisor know about states reported with the request
        for encoded in msg.new_states:
            self.supervisor.add_unique_state(decode_state(encoded))

        poll_started = time.time()
        # presumably the callback is what fills self.samples -- confirm
        self.supervisor.poll_loop(do_stats=True, cb=self.cb_from_superviser)

        summary_fmt = "#samp %d, pred()s %d/%d, py/pred/all %.1f/%.1f/%.1f"
        summary_args = (len(self.samples),
                        self.supervisor.num_predictions_calls,
                        self.supervisor.total_predictions,
                        self.supervisor.acc_time_polling,
                        self.supervisor.acc_time_prediction,
                        time.time() - poll_started)
        log.info(summary_fmt % summary_args)

        response = msgs.RequestSampleResponse(self.samples, 0)
        server.send_msg(response)
Ejemplo n.º 5
0
    def verify_samples(self, sm):
        """Assert that every legal move of every sample appears in its policy.

        Each sample in ``self.samples`` has its state decoded and loaded into
        the state machine ``sm``.  For both roles (0 and 1), every legal move
        reported by ``sm`` must have an entry in that role's policy; a missing
        entry raises ``AssertionError``.

        While checking, per-role tallies are kept: how many times each legal
        move was seen, and the largest and smallest probability found for it.
        NOTE(review): these tallies are locals and are not returned or
        reported by the code visible here.
        """
        # a single base state is reused for every sample
        basestate = sm.new_base_state()

        counters = [Counter(), Counter()]
        max_values = [{}, {}]
        min_values = [{}, {}]
        for s in self.samples:
            basestate.from_list(decode_state(s.state))
            sm.update_bases(basestate)

            for ri in range(2):
                ls = sm.get_legal_state(ri)
                policy = s.policies[ri]
                for legal in ls.to_list():
                    found = False
                    for ll, pp in policy:
                        if ll == legal:
                            # -1 and 2 are starting sentinels; presumably they
                            # bracket the valid probability range -- confirm
                            max_values[ri][legal] = max(
                                max_values[ri].get(legal, -1), pp)
                            # the running minimum must be read from min_values
                            # (reading max_values made it track the latest
                            # probability instead of the smallest one)
                            min_values[ri][legal] = min(
                                min_values[ri].get(legal, 2), pp)
                            found = True
                            break
                    assert found, (
                        "legal %s of role %s has no policy entry" % (legal, ri))
                    counters[ri][legal] += 1
Ejemplo n.º 6
0
    def sample_to_nn(self, sample, inputs, outputs):
        """Append one sample's network input and targets to the given lists.

        ``inputs`` receives the result of ``state_to_channels`` for the decoded
        state and decoded previous states.  ``outputs`` receives a list holding
        one policy array per role followed by a float32 array of
        ``sample.final_score``.  Exactly two roles are expected (asserted).
        """
        # input planes, built from the current state and its history
        current = decode_state(sample.state)
        history = [decode_state(prev) for prev in sample.prev_states]
        inputs.append(self.state_to_channels(current, history))

        # policy targets, one array per role
        assert self.role_count == 2
        assert len(self.policy_dist_count) == 2
        targets = [self.policy_to_array(sample.policies[role], role)
                   for role in range(self.role_count)]

        # final scores go last
        targets.append(np.array(sample.final_score, dtype='float32'))
        outputs.append(targets)
Ejemplo n.º 7
0
    def check_sample(self, sample):
        """Sanity-check a sample via assertions and return it unchanged.

        Checks the decoded state length against ``self.num_bases``, the number
        of final scores against ``self.final_score_count``, the sample's type,
        and that each policy's probabilities lie within (-0.01, 1.01) and sum
        to within (0.99, 1.01).  Raises ``AssertionError`` on any violation.
        """
        # XXX ideally this would be an equality check, but encode/decode can
        # end up padding, so the decoded state may be longer than num_bases.
        decoded_length = len(decode_state(sample.state))
        assert decoded_length >= self.num_bases
        assert len(sample.final_score) == self.final_score_count

        assert isinstance(sample, datadesc.Sample)
        for policy in sample.policies:
            mass = 0.0
            for _move, prob in policy:
                # small tolerance either side of [0, 1]
                assert -0.01 < prob < 1.01
                mass += prob

            # the policy should sum to roughly one
            assert 0.99 < mass < 1.01

        return sample