Example #1
def default_env_constructor(players, hparams, power_assignments,
                            set_player_seed, initial_state_bytes):
    """ Default gym environment constructor
        :param players: A list of instantiated players
        :param hparams: A dictionary of hyper parameters with their values
        :param power_assignments: Optional. The power name we want to play as (e.g. 'FRANCE'), or a list of power names.
        :param set_player_seed: Boolean indicating whether to set the player seed on reset().
        :param initial_state_bytes: A `game.State` proto (in bytes format) representing the initial state of the game.
        :rtype: diplomacy_research.models.gym.wrappers.DiplomacyWrapper
    """
    # The keys should be present in hparams if it's not None
    max_nb_years = hparams['max_nb_years'] if hparams else 35
    auto_draw = hparams['auto_draw'] if hparams else False
    power = hparams['power'] if hparams else ''
    nb_thrashing_states = hparams['nb_thrashing_states'] if hparams else 0

    env = gym.make('DiplomacyEnv-v0')
    env = LimitNumberYears(env, max_nb_years)

    # Loop Detection (Thrashing)
    if nb_thrashing_states:
        env = LoopDetection(env, threshold=nb_thrashing_states)

    # Auto-Drawing
    if auto_draw:
        env = AutoDraw(env)

    # Setting initial state
    if initial_state_bytes is not None:
        env = SetInitialState(env,
                              bytes_to_proto(initial_state_bytes, StateProto))

    # 1) If power_assignments is a list, using that to assign powers
    if isinstance(power_assignments, list):
        env = AssignPlayers(env, players, power_assignments)

    # 2a) If power_assignments is a string, using that as our power, and randomly assigning the other powers
    # 2b) Using hparams['power'] if it's set.
    elif isinstance(power_assignments, str) or power:
        our_power = power_assignments if isinstance(power_assignments, str) else power
        other_powers = [power_name for power_name in env.get_all_powers_name() if power_name != our_power]
        shuffle(other_powers)
        env = AssignPlayers(env, players, [our_power] + other_powers)

    # 3) Randomly shuffle the powers
    else:
        env = RandomizePlayers(env, players)

    # Setting player seed
    if set_player_seed:
        env = SetPlayerSeed(env, players)

    # Ability to save game / retrieve saved game
    env = SaveGame(env)
    return env
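
A minimal usage sketch for the constructor above. The player objects, the hparams values, and the 'FRANCE' assignment are placeholders (RandomPlayer is assumed here purely for illustration), not part of the original example.

# Hypothetical call: play as FRANCE against six shuffled opponents, seeding each player on reset()
players = [RandomPlayer() for _ in range(7)]    # placeholder: any 7 instantiated players
hparams = {'max_nb_years': 35, 'auto_draw': False, 'power': '', 'nb_thrashing_states': 0}

env = default_env_constructor(players,
                              hparams,
                              power_assignments='FRANCE',
                              set_player_seed=True,
                              initial_state_bytes=None)     # None = start from the standard initial state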
Example #2
def data_generator(saved_game_bytes, is_validation_set):
    """ Converts a dataset game to protocol buffer format
        :param saved_game_bytes: A `.proto.game.SavedGame` object (in bytes format) from the dataset.
        :param is_validation_set: Boolean indicating whether we are generating the validation set (otherwise, the training set).
        :return: A dictionary with phase_ix as key and a dictionary {power_name: (msg_len, proto)} as value
    """
    saved_game_proto = bytes_to_proto(saved_game_bytes, SavedGameProto)
    if not keep_phase_in_dataset(saved_game_proto):
        return {phase_ix: [] for phase_ix, _ in enumerate(saved_game_proto.phases)}
    return root_data_generator(saved_game_proto, is_validation_set)
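
A hedged sketch of how data_generator might be driven from the proto dataset on disk, reusing read_next_bytes and PROTO_DATASET_PATH as in Examples 3 and 5; the loop itself is illustrative, not part of the library.

# Illustrative only: iterate the length-delimited dataset and convert each saved game
with open(PROTO_DATASET_PATH, 'rb') as proto_file:
    while True:
        saved_game_bytes = read_next_bytes(proto_file)
        if saved_game_bytes is None:                  # end of file
            break
        phases = data_generator(saved_game_bytes, is_validation_set=False)
        # phases: {phase_ix: {power_name: (msg_len, proto)}} for kept games, empty lists otherwise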
Example #3
def build():
    """ Building the Redis dataset """
    if not os.path.exists(PROTO_DATASET_PATH):
        raise RuntimeError('Unable to find the proto dataset at %s' % PROTO_DATASET_PATH)

    # Creating output directory if it doesn't exist
    os.makedirs(os.path.join(WORKING_DIR, 'containers', 'redis'), exist_ok=True)

    # Starting the Redis server and blocking on that thread
    redis_thread = Thread(target=start_redis, kwargs={'save_dir': os.path.join(WORKING_DIR, 'containers'),
                                                      'log_file_path': os.devnull,
                                                      'clear': True})
    redis_thread.start()

    # Creating a memory buffer object to save games in Redis
    memory_buffer = MemoryBuffer()
    memory_buffer.clear()

    # Loading the phases count dataset to get the number of games
    total = None
    if os.path.exists(PHASES_COUNT_DATASET_PATH):
        with open(PHASES_COUNT_DATASET_PATH, 'rb') as file:
            total = len(pickle.load(file))
    progress_bar = tqdm(total=total)

    # Loading dataset and converting
    LOGGER.info('... Creating redis dataset.')
    with open(PROTO_DATASET_PATH, 'rb') as file:
        while True:
            saved_game_bytes = read_next_bytes(file)
            if saved_game_bytes is None:
                break
            progress_bar.update(1)
            saved_game_proto = bytes_to_proto(saved_game_bytes, SavedGameProto)
            save_expert_games(memory_buffer, [bytes_to_zlib(saved_game_bytes)], [saved_game_proto.id])

    # Saving
    memory_buffer.save(sync=True)

    # Moving file
    redis_db_path = {True: '/work_dir/redis/saved_redis.rdb',
                     False: os.path.join(WORKING_DIR, 'containers', 'redis', 'saved_redis.rdb')}.get(IN_PRODUCTION)
    shutil.move(redis_db_path, REDIS_DATASET_PATH)
    LOGGER.info('... Done creating redis dataset.')

    # Stopping Redis and thread
    progress_bar.close()
    memory_buffer.shutdown()
    redis_thread.join(timeout=60)
Example #4
def save_games(buffer, saved_games_proto=None, saved_games_bytes=None):
    """ Stores a series of games in compressed saved game proto format
        :param buffer: An instance of the memory buffer.
        :param saved_games_bytes: List of saved game (bytes format) (either completed or partial)
        :param saved_games_proto: List of saved game proto (either completed or partial)
        :return: Nothing
        :type buffer: diplomacy_research.models.training.memory_buffer.MemoryBuffer
    """
    assert (saved_games_bytes is None) != (saved_games_proto is None), 'Expected one of bytes or proto'
    saved_games_bytes = saved_games_bytes or []
    saved_games_proto = saved_games_proto or []

    if saved_games_bytes:
        saved_games_proto = [bytes_to_proto(game_bytes, SavedGameProto) for game_bytes in saved_games_bytes]

    # Splitting between completed game ids and partial game ids
    completed, partial = [], []
    for saved_game_proto in saved_games_proto:
        if not saved_game_proto.is_partial_game:
            completed += [saved_game_proto]
        else:
            partial += [saved_game_proto]

    # No games
    if not completed and not partial:
        LOGGER.warning(
            'Trying to save saved_games_proto, but no games provided. Skipping.'
        )
        return

    # Compressing games
    completed_games_zlib = [proto_to_zlib(saved_game_proto) for saved_game_proto in completed]
    completed_game_ids = [saved_game_proto.id for saved_game_proto in completed]

    partial_games_zlib = [proto_to_zlib(saved_game_proto) for saved_game_proto in partial]
    partial_game_ids = [__PARTIAL_GAME_ID__ % (saved_game_proto.id, saved_game_proto.phases[-1].name)
                        for saved_game_proto in partial]

    # Saving games
    pipeline = buffer.redis.pipeline()
    if completed_game_ids:
        for game_id, saved_game_zlib in zip(completed_game_ids,
                                            completed_games_zlib):
            pipeline.set(__ONLINE_GAME__ % game_id, saved_game_zlib)
        pipeline.sadd(__SET_ONLINE_GAMES__, *completed_game_ids)

    if partial_game_ids:
        for game_id, saved_game_zlib in zip(partial_game_ids,
                                            partial_games_zlib):
            pipeline.set(__PARTIAL_GAME__ % game_id, saved_game_zlib)
        pipeline.sadd(__SET_PARTIAL_GAMES__, *partial_game_ids)

    # Executing
    pipeline.execute()
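
The assertion at the top of save_games enforces that exactly one of the two optional arguments is provided. A minimal sketch of both call patterns, assuming a MemoryBuffer instance and a saved game are already available:

# Either pass already-parsed protos ...
save_games(memory_buffer, saved_games_proto=[saved_game_proto])

# ... or raw serialized bytes (converted internally with bytes_to_proto)
save_games(memory_buffer, saved_games_bytes=[saved_game_bytes])

# Passing both, or neither, trips the assertion 'Expected one of bytes or proto'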
Example #5
    def generate_proto_files(self, validation_perc=VALIDATION_SET_SPLIT):
        """ Generates train and validation protocol buffer files
            :param validation_perc: The percentage of the dataset to use to generate the validation dataset.
        """
        # pylint: disable=too-many-nested-blocks,too-many-statements,too-many-branches
        from diplomacy_research.utils.tensorflow import tf

        # The protocol buffers files have already been generated
        if os.path.exists(self.training_dataset_path) \
                and os.path.exists(self.validation_dataset_path) \
                and os.path.exists(self.dataset_index_path):
            return

        # Deleting the files if they exist
        shutil.rmtree(self.training_dataset_path, ignore_errors=True)
        shutil.rmtree(self.validation_dataset_path, ignore_errors=True)
        shutil.rmtree(self.dataset_index_path, ignore_errors=True)

        # Making sure the proto_generation_callable is callable
        proto_callable = self.proto_generation_callable
        assert callable(proto_callable), 'The proto_generation_callable must be a callable function'

        # Loading index
        dataset_index = {}

        # Splitting games into training and validation
        with open(PHASES_COUNT_DATASET_PATH, 'rb') as file:
            game_ids = list(pickle.load(file).keys())
        nb_valid = int(validation_perc * len(game_ids))
        set_valid_ids = set(sorted(game_ids)[-nb_valid:])
        train_ids = sorted(game_id for game_id in game_ids if game_id not in set_valid_ids)
        valid_ids = sorted(set_valid_ids)

        # Building a list of all task phases
        train_task_phases, valid_task_phases = [], []

        # Case 1) - We just build a list of game_ids with all their phase_ids (sorted by game id asc)
        if not self.sort_dataset_by_phase:
            with open(PHASES_COUNT_DATASET_PATH, 'rb') as phase_count_dataset:
                phase_count_dataset = pickle.load(phase_count_dataset)
                for game_id in train_ids:
                    for phase_ix in range(phase_count_dataset.get(game_id, 0)):
                        train_task_phases += [(phase_ix, game_id)]
                for game_id in valid_ids:
                    for phase_ix in range(phase_count_dataset.get(game_id, 0)):
                        valid_task_phases += [(phase_ix, game_id)]

        # Case 2) - We build 10 groups sorted by game_id asc, and for each group we put all phase_ix 0, then
        #           all phase_ix 1, ...
        #         - We need to split into groups so that the model can learn from a mix of game beginnings and endings;
        #           otherwise, the model would only see beginnings at first and, after 1 day of training, only endings
        else:
            with open(PHASES_COUNT_DATASET_PATH, 'rb') as phase_count_dataset:
                phase_count_dataset = pickle.load(phase_count_dataset)
                nb_groups = 10
                nb_items_per_group = math.ceil(len(train_ids) / nb_groups)

                # Training
                for group_ix in range(nb_groups):
                    group_game_ids = train_ids[group_ix * nb_items_per_group:(group_ix + 1) * nb_items_per_group]
                    group_task_phases = []
                    for game_id in group_game_ids:
                        for phase_ix in range(phase_count_dataset.get(game_id, 0)):
                            group_task_phases += [(phase_ix, game_id)]
                    train_task_phases += list(sorted(group_task_phases))

                # Validation
                for game_id in valid_ids:
                    for phase_ix in range(phase_count_dataset.get(game_id, 0)):
                        valid_task_phases += [(phase_ix, game_id)]
                valid_task_phases = list(sorted(valid_task_phases))

        # Grouping tasks by buckets
        # buckets_pending contains a set of pending task ids in each bucket
        # buckets_keys contains a list of tuples so we can group items with similar message length together
        task_to_bucket = {}  # {task_id: bucket_id}
        train_buckets_pending, valid_buckets_pending = [], []  # [bucket_id: set()]
        train_buckets_keys, valid_buckets_keys = [], []        # [bucket_id: (msg_len, task_id, power_name)]
        if self.group_by_message_length:
            nb_valid_buckets = int(VALIDATION_SET_SPLIT * 100)
            nb_train_buckets = 100

            # Train buckets
            task_id = 1
            nb_items_per_bucket = math.ceil(len(train_task_phases) / nb_train_buckets)
            for bucket_ix in range(nb_train_buckets):
                items = train_task_phases[bucket_ix * nb_items_per_bucket:(bucket_ix + 1) * nb_items_per_bucket]
                nb_items = len(items)
                train_buckets_pending.append(set())
                train_buckets_keys.append([])

                for _ in range(nb_items):
                    train_buckets_pending[bucket_ix].add(task_id)
                    task_to_bucket[task_id] = bucket_ix
                    task_id += 1

            # Valid buckets
            task_id = -1
            nb_items_per_bucket = math.ceil(len(valid_task_phases) / nb_valid_buckets)
            for bucket_ix in range(nb_valid_buckets):
                items = valid_task_phases[bucket_ix * nb_items_per_bucket:(bucket_ix + 1) * nb_items_per_bucket]
                nb_items = len(items)
                valid_buckets_pending.append(set())
                valid_buckets_keys.append([])

                for _ in range(nb_items):
                    valid_buckets_pending[bucket_ix].add(task_id)
                    task_to_bucket[task_id] = bucket_ix
                    task_id -= 1

        # Building a dictionary of {game_id: {phase_ix: task_id}}
        # Train tasks have an id > 0, valid tasks have an id < 0
        task_id = 1
        task_id_per_game = {}
        for phase_ix, game_id in train_task_phases:
            task_id_per_game.setdefault(game_id, {})[phase_ix] = task_id
            task_id += 1

        task_id = -1
        for phase_ix, game_id in valid_task_phases:
            task_id_per_game.setdefault(game_id, {})[phase_ix] = task_id
            task_id -= 1

        # Building a dictionary of pending items, so we can write them to disk in the correct order
        nb_train_tasks = len(train_task_phases)
        nb_valid_tasks = len(valid_task_phases)
        pending_train_tasks = OrderedDict({task_id: None for task_id in range(1, nb_train_tasks + 1)})
        pending_valid_tasks = OrderedDict({task_id: None for task_id in range(1, nb_valid_tasks + 1)})

        # Computing batch_size, progress bar and creating a pool of processes
        batch_size = 5120
        progress_bar = tqdm(total=nb_train_tasks + nb_valid_tasks)
        process_pool = ProcessPoolExecutor()
        futures = set()

        # Creating buffer to write all protos to disk at once
        train_buffer, valid_buffer = Queue(), Queue()

        # Opening the proto file to read games
        proto_dataset = open(PROTO_DATASET_PATH, 'rb')
        nb_items_being_processed = 0

        # Creating training and validation dataset
        for training_mode in ['train', 'valid']:
            next_key = 1
            current_bucket = 0

            if training_mode == 'train':
                pending_tasks = pending_train_tasks
                buckets_pending = train_buckets_pending
                buckets_keys = train_buckets_keys
                buffer = train_buffer
                max_next_key = nb_train_tasks + 1
            else:
                pending_tasks = pending_valid_tasks
                buckets_pending = valid_buckets_pending
                buckets_keys = valid_buckets_keys
                buffer = valid_buffer
                max_next_key = nb_valid_tasks + 1
            dataset_index['size_{}_dataset'.format(training_mode)] = 0

            # Processing with a queue to avoid high memory usage
            while pending_tasks:

                # Filling queues
                while batch_size > nb_items_being_processed:
                    saved_game_bytes = read_next_bytes(proto_dataset)
                    if saved_game_bytes is None:
                        break
                    saved_game_proto = bytes_to_proto(saved_game_bytes,
                                                      SavedGameProto)
                    game_id = saved_game_proto.id
                    if game_id not in task_id_per_game:
                        continue
                    nb_phases = len(saved_game_proto.phases)
                    task_ids = [
                        task_id_per_game[game_id][phase_ix]
                        for phase_ix in range(nb_phases)
                    ]
                    futures.add(
                        (tuple(task_ids),
                         process_pool.submit(handle_queues, task_ids,
                                             proto_callable, saved_game_bytes,
                                             task_ids[0] < 0)))
                    nb_items_being_processed += nb_phases

                # Processing results
                for expected_task_ids, future in list(futures):
                    if not future.done():
                        continue
                    results = future.result()
                    current_task_ids = set()

                    # Storing in compressed format in memory
                    for task_id, power_name, message_lengths, proto_result in results:
                        current_task_ids.add(task_id)

                        if proto_result is not None:
                            zlib_result = proto_to_zlib(proto_result)
                            if task_id > 0:
                                if pending_train_tasks[abs(task_id)] is None:
                                    pending_train_tasks[abs(task_id)] = {}
                                pending_train_tasks[abs(task_id)][power_name] = zlib_result
                            else:
                                if pending_valid_tasks[abs(task_id)] is None:
                                    pending_valid_tasks[abs(task_id)] = {}
                                pending_valid_tasks[abs(task_id)][power_name] = zlib_result

                            if self.group_by_message_length:
                                task_bucket_id = task_to_bucket[task_id]
                                if task_id > 0:
                                    train_buckets_keys[task_bucket_id].append((message_lengths, task_id, power_name))
                                else:
                                    valid_buckets_keys[task_bucket_id].append((message_lengths, task_id, power_name))

                        # No results - Marking task id as done
                        elif task_id > 0 and pending_train_tasks[abs(task_id)] is None:
                            del pending_train_tasks[abs(task_id)]
                        elif task_id < 0 and pending_valid_tasks[abs(task_id)] is None:
                            del pending_valid_tasks[abs(task_id)]

                    # Missing some task ids
                    if set(expected_task_ids) != current_task_ids:
                        LOGGER.warning(
                            'Missing tasks ids. Got %s - Expected: %s',
                            current_task_ids, expected_task_ids)
                        current_task_ids = expected_task_ids

                    # Marking tasks as completed
                    nb_items_being_processed -= len(expected_task_ids)
                    progress_bar.update(len(current_task_ids))

                    # Marking items as not pending in buckets
                    if self.group_by_message_length:
                        for task_id in current_task_ids:
                            task_bucket_id = task_to_bucket[task_id]
                            if task_id > 0:
                                train_buckets_pending[task_bucket_id].remove(task_id)
                            else:
                                valid_buckets_pending[task_bucket_id].remove(task_id)

                    # Deleting futures to release memory
                    futures.remove((expected_task_ids, future))
                    del future

                # Writing to disk
                while True:
                    if self.group_by_message_length:

                        # Done all buckets
                        if current_bucket >= len(buckets_pending):
                            break

                        # Still waiting for tasks in the current bucket
                        if buckets_pending[current_bucket]:
                            break

                        # Bucket was empty - We can look at next bucket
                        if not buckets_keys[current_bucket]:
                            current_bucket += 1
                            break

                        # Sorting items in bucket before writing them in buffer
                        items_in_bucket = list(sorted(buckets_keys[current_bucket]))
                        for _, task_id, power_name in items_in_bucket:
                            zlib_result = pending_tasks[abs(task_id)][power_name]
                            buffer.put(zlib_result)
                            dataset_index['size_{}_dataset'.format(training_mode)] += 1
                            del pending_tasks[abs(task_id)][power_name]
                            if not pending_tasks[abs(task_id)]:
                                del pending_tasks[abs(task_id)]
                        current_bucket += 1
                        del items_in_bucket
                        break

                    # Writing to buffer in the same order as they are received
                    if next_key >= max_next_key:
                        break
                    if next_key not in pending_tasks:
                        next_key += 1
                        continue
                    if pending_tasks[next_key] is None:
                        break
                    zlib_results = pending_tasks.pop(next_key)
                    for zlib_result in zlib_results.values():
                        buffer.put(zlib_result)
                        dataset_index['size_{}_dataset'.format(training_mode)] += 1
                    next_key += 1
                    del zlib_results

        # Stopping pool, and progress bar
        process_pool.shutdown(wait=True)
        progress_bar.close()
        proto_dataset.close()

        # Storing protos to disk
        LOGGER.info('Writing protos to disk...')
        progress_bar = tqdm(total=train_buffer.qsize() + valid_buffer.qsize())
        options = tf.io.TFRecordOptions(
            compression_type=tf.io.TFRecordCompressionType.GZIP)

        with tf.io.TFRecordWriter(self.training_dataset_path,
                                  options=options) as dataset_writer:
            while not train_buffer.empty():
                zlib_result = train_buffer.get()
                dataset_writer.write(zlib_to_bytes(zlib_result))
                progress_bar.update(1)

        with tf.io.TFRecordWriter(self.validation_dataset_path,
                                  options=options) as dataset_writer:
            while not valid_buffer.empty():
                zlib_result = valid_buffer.get()
                dataset_writer.write(zlib_to_bytes(zlib_result))
                progress_bar.update(1)

        with open(self.dataset_index_path, 'wb') as dataset_index_file:
            pickle.dump(dataset_index, dataset_index_file,
                        pickle.HIGHEST_PROTOCOL)

        # Closing
        progress_bar.close()
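
A small self-contained sketch (with made-up game ids) of the task-id convention used above: training phases receive increasing positive ids, validation phases receive decreasing negative ids, and task_id_per_game maps {game_id: {phase_ix: task_id}}.

# Hypothetical data illustrating the id-assignment loops from generate_proto_files
train_task_phases = [(0, 'game_a'), (1, 'game_a'), (0, 'game_b')]
valid_task_phases = [(0, 'game_c')]

task_id_per_game = {}
task_id = 1
for phase_ix, game_id in train_task_phases:
    task_id_per_game.setdefault(game_id, {})[phase_ix] = task_id
    task_id += 1
task_id = -1
for phase_ix, game_id in valid_task_phases:
    task_id_per_game.setdefault(game_id, {})[phase_ix] = task_id
    task_id -= 1

# task_id_per_game == {'game_a': {0: 1, 1: 2}, 'game_b': {0: 3}, 'game_c': {0: -1}}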
Example #6
def test_to_from_bytes():
    """ Tests proto_to_bytes and bytes_to_proto """
    message_proto = _get_message()
    message_bytes = proto_to_bytes(message_proto)
    new_message_proto = bytes_to_proto(message_bytes, Message)
    _compare_messages(message_proto, new_message_proto)
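
For comparison, a sketch of the compressed round-trip used in Examples 4 and 5, assuming zlib_to_bytes reverses proto_to_zlib back into the same serialized form that bytes_to_proto expects:

def test_to_from_zlib():
    """ Sketch only: compressed round-trip (assumes zlib_to_bytes is the inverse of proto_to_zlib) """
    message_proto = _get_message()
    message_zlib = proto_to_zlib(message_proto)                        # proto -> compressed blob
    new_message_proto = bytes_to_proto(zlib_to_bytes(message_zlib), Message)
    _compare_messages(message_proto, new_message_proto)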