Example 1
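This snippet, from the diplomacy_research test suite, builds a combined policy/value/draw model, plays a game between six model-based players and one rule-based player, and validates the resulting saved-game protobuf. It assumes module-level imports (gym plus the diplomacy_research model, dataset, player, and environment classes) that are omitted from the excerpt.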
    def test_generate_saved_game_proto(self):
        """ Tests the generate_saved_game_proto method """
        from diplomacy_research.utils.tensorflow import tf
        hparams = self.parse_flags(load_args() + load_value_args() + load_draw_args())

        # Generating model
        graph = tf.Graph()
        with graph.as_default():
            dataset = QueueDataset(batch_size=32, dataset_builder=BaseDatasetBuilder())
            model = PolicyModel(dataset, hparams)
            model = ValueModel(model, dataset, hparams)
            model = DrawModel(model, dataset, hparams)
            model.finalize_build()
            adapter = PolicyAdapter(dataset, graph, tf.Session(graph=graph))
            advantage_fn = MonteCarlo(gamma=0.99)
            reward_fn = DefaultRewardFunction()

            # Creating players
            player = ModelBasedPlayer(adapter)
            rule_player = RuleBasedPlayer(easy_ruleset)
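            # The same model-based player fills six of the seven power slots;
            # the rule-based player takes the seventh (Diplomacy has seven powers)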
            players = [player, player, player, player, player, player, rule_player]

            def env_constructor(players):
                """ Env constructor """
                env = gym.make('DiplomacyEnv-v0')
                env = LimitNumberYears(env, 5)
                env = RandomizePlayers(env, players)
                env = AutoDraw(env)
                return env

            # Generating game
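            # The 'yield' assumes this test runs as a coroutine under an
            # async-aware (tornado-style) test runner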
            saved_game_proto = yield generate_trajectory(players, reward_fn, advantage_fn, env_constructor)

        # Validating game
        assert saved_game_proto.id
        assert len(saved_game_proto.phases) >= 10

        # Validating policy details
        for phase in saved_game_proto.phases:
            for power_name in phase.policy:
                nb_locs = len(phase.policy[power_name].locs)
                assert (len(phase.policy[power_name].tokens) == nb_locs * TOKENS_PER_ORDER          # Token-based
                        or len(phase.policy[power_name].tokens) == nb_locs)                         # Order-based
                assert len(phase.policy[power_name].log_probs) == len(phase.policy[power_name].tokens)
                assert phase.policy[power_name].draw_action in (True, False)
                assert 0. <= phase.policy[power_name].draw_prob <= 1.

        # Validating assignments
        assert len(saved_game_proto.assigned_powers) == NB_POWERS

        # Validating rewards and returns
        assert saved_game_proto.reward_fn == DefaultRewardFunction().name
        for power_name in saved_game_proto.assigned_powers:
            assert len(saved_game_proto.rewards[power_name].value) == len(saved_game_proto.phases) - 1
            assert len(saved_game_proto.returns[power_name].value) == len(saved_game_proto.phases) - 1
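Note that rewards and returns are per-transition quantities, which is why each power carries len(phases) - 1 values rather than one per phase.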
Example 2
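Two related snippets from the adapter tests: launch_adapter wires the model constructors into a PolicyAdapterTestSetup and runs it, while build_adapter (a method of the test-case class) builds a policy-only model, wraps it in a PolicyAdapter, and sets the saved-game cache path. As above, module-level imports (including os and the package's HOME_DIR constant) are omitted from the excerpt.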
def launch_adapter():
    """ Launches the tests """
    testable_class = PolicyAdapterTestSetup(policy_model_ctor=PolicyModel,
                                            value_model_ctor=ValueModel,
                                            draw_model_ctor=None,
                                            dataset_builder=BaseDatasetBuilder(),
                                            policy_adapter_ctor=PolicyAdapter,
                                            load_policy_args=load_args,
                                            load_value_args=load_value_args,
                                            load_draw_args=None,
                                            strict=False)
    testable_class.run_tests()
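
    # NOTE: build_adapter below is a method of the adapter test-case class
    # (class body omitted from this snippet); it relies on attributes such as
    # self.graph and self.model_type defined by that class, not on launch_adapter.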
    def build_adapter(self):
        """ Builds adapter """
        from diplomacy_research.utils.tensorflow import tf
        hparams = self.parse_flags(load_args())

        # Generating model
        dataset = QueueDataset(batch_size=32,
                               dataset_builder=BaseDatasetBuilder())
        model = PolicyModel(dataset, hparams)
        model.finalize_build()
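        # Expose a 'state_value' output so the adapter can query value estimates;
        # the first logit serves as a stand-in since no value model is built here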
        model.add_meta_information(
            {'state_value': model.outputs['logits'][:, 0, 0]})
        self.adapter = PolicyAdapter(dataset, self.graph,
                                     tf.Session(graph=self.graph))
        self.create_advantage()

        # Setting cache path
        filename = '%s_savedgame.pbz' % self.model_type
        self.saved_game_cache_path = os.path.join(HOME_DIR, '.cache',
                                                  'diplomacy', filename)