def test_batch_eval_state(self):
    """Evaluate two fresh game states in one batch and check the result sizes."""
    feature_names = ["board", "liberties", "sensibleness", "capture_size"]
    policy = ResnetPolicy(feature_names)

    states = [GameState(), GameState()]
    results = policy.batch_eval_state(states)

    # One result is expected per GameState passed in.
    self.assertEqual(len(results), 2)
    # Each result is expected to hold 361 (move, prob) pairs.
    first_result = results[0]
    self.assertEqual(len(first_result), 361)
def test_save_load(self):
    """Round-trip a ResnetPolicy through save/load and compare the copies.

    The policy is written out twice: once as a JSON model file with the
    weights saved separately through ``policy.model.save_weights``, and once
    through ``save_model`` given both a model path and a weights path. Both
    copies are loaded back; their weights must be equal and the loaded
    object must have the same class as the original policy.

    The fixture files are created in the current working directory and are
    removed again even when an assertion fails.
    """
    policy = ResnetPolicy(["board", "liberties", "sensibleness", "capture_size"])

    model_file = 'TESTPOLICY.json'
    weights_file = 'TESTWEIGHTS.h5'
    model_file2 = 'TESTPOLICY2.json'
    weights_file2 = 'TESTWEIGHTS2.h5'
    written_files = (model_file, weights_file, model_file2, weights_file2)

    try:
        # test saving model/weights separately
        policy.save_model(model_file)
        policy.model.save_weights(weights_file, overwrite=True)
        # test saving them together
        policy.save_model(model_file2, weights_file2)

        copypolicy = ResnetPolicy.load_model(model_file)
        copypolicy.model.load_weights(weights_file)

        # NOTE(review): presumably load_model also restores the weights that
        # were saved alongside model_file2 -- that is what the comparison
        # below relies on; confirm against ResnetPolicy.load_model.
        copypolicy2 = ResnetPolicy.load_model(model_file2)

        weights1 = copypolicy.model.get_weights()
        weights2 = copypolicy2.model.get_weights()
        # zip() stops at the shorter sequence, so compare the lengths first;
        # otherwise a model with missing weight arrays would pass vacuously.
        self.assertEqual(len(weights1), len(weights2))
        for w1, w2 in zip(weights1, weights2):
            self.assertTrue(np.array_equal(w1, w2))

        # check that save/load keeps the ResnetPolicy class
        self.assertEqual(type(policy), type(copypolicy))

        # Every file must actually have been written.
        for path in written_files:
            self.assertTrue(os.path.exists(path))
    finally:
        # Clean up whatever was created, without masking an earlier failure
        # by raising on a file that was never written.
        for path in written_files:
            if os.path.exists(path):
                os.remove(path)
def test_save_load(self):
    """Round-trip a ResnetPolicy through save/load and compare the copies.

    The policy is written out twice: once as a JSON model file with the
    weights saved separately through ``policy.model.save_weights``, and once
    through ``save_model`` given both a model path and a weights path. Both
    copies are loaded back; their weights must be equal and the loaded
    object must have the same class as the original policy.

    The fixture files are created in the current working directory and are
    removed again even when an assertion fails.
    """
    policy = ResnetPolicy(["board", "liberties", "sensibleness", "capture_size"])

    model_file = "TESTPOLICY.json"
    weights_file = "TESTWEIGHTS.h5"
    model_file2 = "TESTPOLICY2.json"
    weights_file2 = "TESTWEIGHTS2.h5"
    written_files = (model_file, weights_file, model_file2, weights_file2)

    try:
        # test saving model/weights separately
        policy.save_model(model_file)
        policy.model.save_weights(weights_file, overwrite=True)
        # test saving them together
        policy.save_model(model_file2, weights_file2)

        copypolicy = ResnetPolicy.load_model(model_file)
        copypolicy.model.load_weights(weights_file)

        # NOTE(review): presumably load_model also restores the weights that
        # were saved alongside model_file2 -- that is what the comparison
        # below relies on; confirm against ResnetPolicy.load_model.
        copypolicy2 = ResnetPolicy.load_model(model_file2)

        weights1 = copypolicy.model.get_weights()
        weights2 = copypolicy2.model.get_weights()
        # zip() stops at the shorter sequence, so compare the lengths first;
        # otherwise a model with missing weight arrays would pass vacuously.
        self.assertEqual(len(weights1), len(weights2))
        for w1, w2 in zip(weights1, weights2):
            self.assertTrue(np.array_equal(w1, w2))

        # check that save/load keeps the ResnetPolicy class
        self.assertEqual(type(policy), type(copypolicy))

        # Every file must actually have been written.
        for path in written_files:
            self.assertTrue(os.path.exists(path))
    finally:
        # Clean up whatever was created, without masking an earlier failure
        # by raising on a file that was never written.
        for path in written_files:
            if os.path.exists(path):
                os.remove(path)
def test_batch_eval_state(self):
    """Check that batch evaluation yields one full-size result per state."""
    policy = ResnetPolicy(
        ["board", "liberties", "sensibleness", "capture_size"]
    )
    batch = [GameState(), GameState()]
    outputs = policy.batch_eval_state(batch)

    result_count = len(outputs)
    self.assertEqual(result_count, 2)  # one result per GameState

    pairs_in_first = len(outputs[0])
    self.assertEqual(pairs_in_first, 361)  # 361 (move, prob) pairs each
def test_default_policy(self):
    """Smoke test: build a policy and evaluate a fresh game state.

    Nothing is asserted; the test passes as long as no exception is raised.
    """
    feature_names = ["board", "liberties", "sensibleness", "capture_size"]
    net = ResnetPolicy(feature_names)
    fresh_state = GameState()
    net.eval_state(fresh_state)