def test_batch_eval_state(self):
    """Evaluate two fresh GameStates in one batch and check the result sizes."""
    feature_names = ["board", "liberties", "sensibleness", "capture_size"]
    policy = ResnetPolicy(feature_names)
    states = [GameState(), GameState()]
    results = policy.batch_eval_state(states)
    # one result per GameState
    self.assertEqual(len(results), 2)
    # each one has 361 (move,prob) pairs
    first_result = results[0]
    self.assertEqual(len(first_result), 361)
Example #2
0
	def test_save_load(self):
		"""Round-trip a ResnetPolicy through save_model / load_model.

		The policy is saved twice: once with the weights written separately
		through ``model.save_weights`` and once with both file names handed to
		``save_model``. Both copies are loaded back, their weight lists are
		compared element by element, and the type of the loaded object is
		checked against the original. Every file written here is removed
		again, even when a save, load or assertion fails part-way through.
		"""
		policy = ResnetPolicy(["board", "liberties", "sensibleness", "capture_size"])

		model_file = 'TESTPOLICY.json'
		weights_file = 'TESTWEIGHTS.h5'
		model_file2 = 'TESTPOLICY2.json'
		weights_file2 = 'TESTWEIGHTS2.h5'
		created_files = (model_file, weights_file, model_file2, weights_file2)

		try:
			# test saving model/weights separately
			policy.save_model(model_file)
			policy.model.save_weights(weights_file, overwrite=True)
			# test saving them together
			policy.save_model(model_file2, weights_file2)

			copypolicy = ResnetPolicy.load_model(model_file)
			copypolicy.model.load_weights(weights_file)

			copypolicy2 = ResnetPolicy.load_model(model_file2)

			weights1 = copypolicy.model.get_weights()
			weights2 = copypolicy2.model.get_weights()
			# zip() stops at the shorter sequence, so a missing layer would go
			# unnoticed without an explicit length check
			self.assertEqual(len(weights1), len(weights2))
			for w1, w2 in zip(weights1, weights2):
				# array_equal is False (rather than broadcasting) on shape mismatch
				self.assertTrue(np.array_equal(w1, w2))

			# check that save/load keeps the ResnetPolicy class
			self.assertEqual(type(policy), type(copypolicy))
		finally:
			# clean up whatever was actually written; skipping missing files
			# keeps an earlier failure from being masked by a removal error
			for path in created_files:
				if os.path.exists(path):
					os.remove(path)
    def test_save_load(self):
        """Round-trip a ResnetPolicy through save_model / load_model.

        The policy is saved twice: once with the weights written separately
        through ``model.save_weights`` and once with both file names handed to
        ``save_model``. Both copies are loaded back, their weight lists are
        compared element by element, and the type of the loaded object is
        checked against the original. Every file written here is removed
        again, even when a save, load or assertion fails part-way through.
        """
        policy = ResnetPolicy(["board", "liberties", "sensibleness", "capture_size"])

        model_file = "TESTPOLICY.json"
        weights_file = "TESTWEIGHTS.h5"
        model_file2 = "TESTPOLICY2.json"
        weights_file2 = "TESTWEIGHTS2.h5"
        created_files = (model_file, weights_file, model_file2, weights_file2)

        try:
            # test saving model/weights separately
            policy.save_model(model_file)
            policy.model.save_weights(weights_file, overwrite=True)
            # test saving them together
            policy.save_model(model_file2, weights_file2)

            copypolicy = ResnetPolicy.load_model(model_file)
            copypolicy.model.load_weights(weights_file)

            copypolicy2 = ResnetPolicy.load_model(model_file2)

            weights1 = copypolicy.model.get_weights()
            weights2 = copypolicy2.model.get_weights()
            # zip() stops at the shorter sequence, so a missing layer would go
            # unnoticed without an explicit length check
            self.assertEqual(len(weights1), len(weights2))
            for w1, w2 in zip(weights1, weights2):
                # array_equal is False (rather than broadcasting) on shape mismatch
                self.assertTrue(np.array_equal(w1, w2))

            # check that save/load keeps the ResnetPolicy class
            self.assertEqual(type(policy), type(copypolicy))
        finally:
            # clean up whatever was actually written; skipping missing files
            # keeps an earlier failure from being masked by a removal error
            for path in created_files:
                if os.path.exists(path):
                    os.remove(path)
 def test_batch_eval_state(self):
     """Batch-evaluate two new GameStates and verify the size of the output."""
     features = ["board", "liberties", "sensibleness", "capture_size"]
     policy = ResnetPolicy(features)
     batch = [GameState(), GameState()]
     results = policy.batch_eval_state(batch)
     # one result per GameState
     result_count = len(results)
     self.assertEqual(result_count, 2)
     # each one has 361 (move,prob) pairs
     pair_count = len(results[0])
     self.assertEqual(pair_count, 361)
 def test_default_policy(self):
     """Smoke test: build a ResnetPolicy and evaluate a new GameState with it."""
     features = ["board", "liberties", "sensibleness", "capture_size"]
     policy = ResnetPolicy(features)
     # no assertion: the test only requires that evaluation does not raise
     state = GameState()
     policy.eval_state(state)
Example #6
0
	def test_default_policy(self):
		"""Smoke test: construct a ResnetPolicy and run eval_state on a new GameState."""
		feature_names = ["board", "liberties", "sensibleness", "capture_size"]
		policy = ResnetPolicy(feature_names)
		# nothing is asserted; the call merely has to complete without raising
		fresh_state = GameState()
		policy.eval_state(fresh_state)