def test_board_state():
    """ Tests proto_to_board_state and the state proto round-trip """
    game = Game()
    game_map = game.map
    state_proto = state_space.extract_state_proto(game)
    new_game = state_space.build_game_from_state_proto(state_proto)

    # Retrieving board_state
    state_proto_2 = state_space.extract_state_proto(new_game)
    board_state_1 = state_space.proto_to_board_state(state_proto, game_map)
    board_state_2 = state_space.proto_to_board_state(state_proto_2, game_map)

    # Checking
    assert np.allclose(board_state_1, board_state_2)
    assert board_state_1.shape == (state_space.NB_NODES, state_space.NB_FEATURES)
    assert game.get_hash() == new_game.get_hash()

def get_feedable_item(locs, state_proto, power_name, phase_history_proto, possible_orders_proto, **kwargs):
    """ Computes and returns a feedable item (to be fed into the feedable queue)
        :param locs: A list of locations for which we want orders
        :param state_proto: A `.proto.game.State` representation of the state of the game.
        :param power_name: The power name for which we want the orders and the state values
        :param phase_history_proto: A list of `.proto.game.PhaseHistory`. This represents prev phases.
        :param possible_orders_proto: A `proto.game.PossibleOrders` object representing possible orders for each loc.
        :param kwargs: Additional optional kwargs:
            - player_seed: The seed to apply to the player to compute a deterministic mask.
            - noise: The sigma of the additional noise to apply to the intermediate layers (i.e. sigma * epsilon)
            - temperature: The temperature to apply to the logits. (Default to 0. for deterministic/greedy)
            - dropout_rate: The amount of dropout to apply to the inputs/outputs of the decoder.
        :return: A feedable item, with feature names as key and numpy arrays as values
    """
    # pylint: disable=too-many-branches
    # Converting to state space
    map_object = Map(state_proto.map)
    board_state = proto_to_board_state(state_proto, map_object)

    # Building the decoder length
    # For adjustment phase, we restrict the number of builds/disbands to what is allowed by the game engine
    in_adjustment_phase = state_proto.name[-1] == 'A'
    nb_builds = state_proto.builds[power_name].count
    nb_homes = len(state_proto.builds[power_name].homes)

    # If we are in adjustment phase, making sure the locs are the orderable locs (and not the policy locs)
    if in_adjustment_phase:
        orderable_locs, _ = get_orderable_locs_for_powers(state_proto, [power_name])
        if sorted(locs) != sorted(orderable_locs):
            if locs:
                LOGGER.warning('Adj. phase requires orderable locs. Got %s. Expected %s.', locs, orderable_locs)
            locs = orderable_locs

    # WxxxA with builds >= 0 - We can build units
    # WxxxA with builds < 0 - We must disband units
    # Other phases - One decoding step per orderable location
    if in_adjustment_phase and nb_builds >= 0:
        decoder_length = min(nb_builds, nb_homes)
    elif in_adjustment_phase and nb_builds < 0:
        decoder_length = abs(nb_builds)
    else:
        decoder_length = len(locs)

    # Computing the candidates for the policy
    if possible_orders_proto:

        # Adjustment Phase - Use all possible orders for each location.
        if in_adjustment_phase:

            # Building a list of all orders for all locations
            adj_orders = []
            for loc in locs:
                adj_orders += possible_orders_proto[loc].value

            # Computing the candidates
            candidates = [get_order_based_mask(adj_orders)] * decoder_length

        # Regular phase - Compute candidates for each location
        else:
            candidates = []
            for loc in locs:
                candidates += [get_order_based_mask(possible_orders_proto[loc].value)]

    # We don't have possible orders, so we cannot compute candidates
    # This might be normal if we are only getting the state value or the next message to send
    else:
        candidates = []
        for _ in range(decoder_length):
            candidates.append([])

    # Prev orders state
    prev_orders_state = []
    for phase_proto in reversed(phase_history_proto):
        if len(prev_orders_state) == NB_PREV_ORDERS:
            break
        if phase_proto.name[-1] == 'M':
            prev_orders_state = [proto_to_prev_orders_state(phase_proto, map_object)] + prev_orders_state
    for _ in range(NB_PREV_ORDERS - len(prev_orders_state)):
        prev_orders_state = [np.zeros((NB_NODES, NB_ORDERS_FEATURES), dtype=np.uint8)] + prev_orders_state
    prev_orders_state = np.array(prev_orders_state)

    # Building (order) decoder inputs [GO_ID]
    decoder_inputs = [GO_ID]

    # kwargs
    player_seed = kwargs.get('player_seed', 0)
    noise = kwargs.get('noise', 0.)
    temperature = kwargs.get('temperature', 0.)
    dropout_rate = kwargs.get('dropout_rate', 0.)

    # Building feedable data
    item = {
        'player_seed': player_seed,
        'board_state': board_state,
        'board_alignments': get_board_alignments(locs,
                                                 in_adjustment_phase=in_adjustment_phase,
                                                 tokens_per_loc=1,
                                                 decoder_length=decoder_length),
        'prev_orders_state': prev_orders_state,
        'decoder_inputs': decoder_inputs,
        'decoder_lengths': decoder_length,
        'candidates': candidates,
        'noise': noise,
        'temperature': temperature,
        'dropout_rate': dropout_rate,
        'current_power': POWER_VOCABULARY_KEY_TO_IX[power_name],
        'current_season': get_current_season(state_proto)
    }

    # Return
    return item

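# Illustrative sketch (not part of the original module): how get_feedable_item might be called for a
# single power from a live Game object. It assumes the proto extraction helpers used elsewhere in this
# file (extract_state_proto, extract_phase_history_proto, extract_possible_orders_proto,
# get_orderable_locs_for_powers) are importable here; the exact call signatures are assumptions.
def _example_feedable_item(game, power_name='FRANCE'):
    """ Builds a feedable item for `power_name` from the current game (sketch only, hypothetical helper). """
    state_proto = extract_state_proto(game)
    phase_history_proto = extract_phase_history_proto(game)                     # signature assumed
    possible_orders_proto = extract_possible_orders_proto(game)
    locs, _ = get_orderable_locs_for_powers(state_proto, [power_name])
    return get_feedable_item(locs, state_proto, power_name, phase_history_proto, possible_orders_proto,
                             temperature=0.1, noise=0.)
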
def get_policy_data(saved_game_proto, power_names, top_victors):
    """ Computes the proto to save in tf.train.Example as a training example for the policy network
        :param saved_game_proto: A `.proto.game.SavedGame` object from the dataset.
        :param power_names: The list of powers for which we want the policy data
        :param top_victors: The list of powers that ended with more than 25% of the supply centers
        :return: A dictionary with key: the phase_ix
                 with value: A dict with the power_name as key and a dict with the example fields as value
    """
    nb_phases = len(saved_game_proto.phases)
    policy_data = {phase_ix: {} for phase_ix in range(nb_phases - 1)}
    game_id = saved_game_proto.id
    map_object = Map(saved_game_proto.map)

    # Determining if we have a draw
    nb_sc_to_win = len(map_object.scs) // 2 + 1
    has_solo_winner = max([len(saved_game_proto.phases[-1].state.centers[power_name].value)
                           for power_name in saved_game_proto.phases[-1].state.centers]) >= nb_sc_to_win
    survivors = [power_name for power_name in saved_game_proto.phases[-1].state.centers
                 if saved_game_proto.phases[-1].state.centers[power_name].value]
    has_draw = not has_solo_winner and len(survivors) >= 2

    # Processing all phases (except the last one)
    current_year = 0
    for phase_ix in range(nb_phases - 1):

        # Building a list of orders of previous phases
        previous_orders_states = [np.zeros((NB_NODES, NB_ORDERS_FEATURES), dtype=np.uint8)] * NB_PREV_ORDERS
        for phase_proto in saved_game_proto.phases[max(0, phase_ix - NB_PREV_ORDERS_HISTORY):phase_ix]:
            if phase_proto.name[-1] == 'M':
                previous_orders_states += [proto_to_prev_orders_state(phase_proto, map_object)]
        previous_orders_states = previous_orders_states[-NB_PREV_ORDERS:]
        prev_orders_state = np.array(previous_orders_states)

        # Parsing each requested power in the specified phase
        phase_proto = saved_game_proto.phases[phase_ix]
        phase_name = phase_proto.name
        state_proto = phase_proto.state
        phase_board_state = proto_to_board_state(state_proto, map_object)

        # Increasing year for every spring or when the game is completed
        if phase_proto.name == 'COMPLETED' or (phase_proto.name[0] == 'S' and phase_proto.name[-1] == 'M'):
            current_year += 1

        for power_name in power_names:
            phase_issued_orders = get_issued_orders_for_powers(phase_proto, [power_name])
            phase_possible_orders = get_possible_orders_for_powers(phase_proto, [power_name])
            phase_draw_target = 1. if has_draw and phase_ix == (nb_phases - 2) and power_name in survivors else 0.

            # Data to use when not learning a policy
            blank_policy_data = {'board_state': phase_board_state,
                                 'prev_orders_state': prev_orders_state,
                                 'draw_target': phase_draw_target}

            # Power is not a top victor - We don't want to learn a policy from it
            if power_name not in top_victors:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Finding the orderable locs
            orderable_locations = list(phase_issued_orders[power_name].keys())

            # Skipping power for this phase if we are only issuing Hold
            for order_loc, order in phase_issued_orders[power_name].items():
                order_tokens = get_order_tokens(order)
                if len(order_tokens) >= 2 and order_tokens[1] != 'H':
                    break
            else:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Removing orderable locs where orders are not possible (i.e. NO_CHECK games)
            for order_loc, order in phase_issued_orders[power_name].items():
                if order not in phase_possible_orders[order_loc] and order_loc in orderable_locations:
                    if 'NO_CHECK' not in saved_game_proto.rules:
                        LOGGER.warning('%s not in all possible orders. Phase %s - Game %s.',
                                       order, phase_name, game_id)
                    orderable_locations.remove(order_loc)

                # Remove orderable locs where the order is either invalid or not frequent
                if order_to_ix(order) is None and order_loc in orderable_locations:
                    orderable_locations.remove(order_loc)

            # Determining if we are in an adjustment phase
            in_adjustment_phase = state_proto.name[-1] == 'A'
            nb_builds = state_proto.builds[power_name].count
            nb_homes = len(state_proto.builds[power_name].homes)

            # WxxxA with builds >= 0 - We can build units
            # WxxxA with builds < 0 - We must disband units
            # Other phases - One decoding step per orderable location
            if in_adjustment_phase and nb_builds >= 0:
                decoder_length = min(nb_builds, nb_homes)
            elif in_adjustment_phase and nb_builds < 0:
                decoder_length = abs(nb_builds)
            else:
                decoder_length = len(orderable_locations)

            # Not all units were disbanded - Skipping this power as we can't learn the orders properly
            if in_adjustment_phase and nb_builds < 0 and len(orderable_locations) < abs(nb_builds):
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Not enough orderable locations for this power, skipping
            if not orderable_locations or not decoder_length:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # decoder_inputs [GO, order1, order2, order3]
            decoder_inputs = [GO_ID]
            decoder_inputs += [order_to_ix(phase_issued_orders[power_name][loc]) for loc in orderable_locations]
            if in_adjustment_phase and nb_builds > 0:
                decoder_inputs += [order_to_ix('WAIVE')] * (min(nb_builds, nb_homes) - len(orderable_locations))
            decoder_length = min(decoder_length, NB_SUPPLY_CENTERS)

            # Adjustment Phase - Use all possible orders for each location.
            if in_adjustment_phase:
                build_disband_locs = list(get_possible_orders_for_powers(phase_proto, [power_name]).keys())
                phase_board_alignments = get_board_alignments(build_disband_locs,
                                                              in_adjustment_phase=in_adjustment_phase,
                                                              tokens_per_loc=1,
                                                              decoder_length=decoder_length)

                # Building a list of all orders for all locations
                adj_orders = []
                for loc in build_disband_locs:
                    adj_orders += phase_possible_orders[loc]

                # Not learning builds for BUILD_ANY
                if nb_builds > 0 and 'BUILD_ANY' in state_proto.rules:
                    adj_orders = []

                # No orders found - Skipping
                if not adj_orders:
                    policy_data[phase_ix][power_name] = blank_policy_data
                    continue

                # Computing the candidates
                candidates = [get_order_based_mask(adj_orders)] * decoder_length

            # Regular phase - Compute candidates for each location
            else:
                phase_board_alignments = get_board_alignments(orderable_locations,
                                                              in_adjustment_phase=in_adjustment_phase,
                                                              tokens_per_loc=1,
                                                              decoder_length=decoder_length)
                candidates = []
                for loc in orderable_locations:
                    candidates += [get_order_based_mask(phase_possible_orders[loc])]

            # Saving results
            # No need to return temperature, current_power, current_season
            policy_data[phase_ix][power_name] = {'board_state': phase_board_state,
                                                 'board_alignments': phase_board_alignments,
                                                 'prev_orders_state': prev_orders_state,
                                                 'decoder_inputs': decoder_inputs,
                                                 'decoder_lengths': decoder_length,
                                                 'candidates': candidates,
                                                 'draw_target': phase_draw_target}

    # Returning
    return policy_data

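# Illustrative sketch (not part of the original module): consuming the output of get_policy_data above.
# Entries that only contain 'board_state', 'prev_orders_state' and 'draw_target' are the blank items
# (no policy learned for that power/phase); a caller would typically skip them when serializing
# tf.train.Example protos. The helper name below is hypothetical.
def _iter_policy_examples(saved_game_proto, power_names, top_victors):
    """ Yields (phase_ix, power_name, example) for phases where a policy target exists (sketch only). """
    policy_data = get_policy_data(saved_game_proto, power_names, top_victors)
    for phase_ix in sorted(policy_data):
        for power_name, example in policy_data[phase_ix].items():
            if 'decoder_inputs' not in example:
                continue                                    # blank_policy_data - nothing to learn here
            yield phase_ix, power_name, example
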
def generate_trajectory(players, reward_fn, advantage_fn, env_constructor=None, hparams=None, power_assignments=None,
                        set_player_seed=None, initial_state_bytes=None, update_interval=0, update_queue=None,
                        output_format='proto'):
    """ Generates a single trajectory (Saved Game Proto) for RL (self-play) with the power assignments
        :param players: A list of instantiated players
        :param reward_fn: The reward function to use to calculate rewards
        :param advantage_fn: An instance of `.models.self_play.advantages`
        :param env_constructor: A callable to get the OpenAI gym environment (args: players)
        :param hparams: A dictionary of hyper parameters with their values
        :param power_assignments: Optional. The power name we want to play as. (e.g. 'FRANCE') or a list of powers.
        :param set_player_seed: Boolean that indicates that we want to set the player seed on reset().
        :param initial_state_bytes: A `game.State` proto (in bytes format) representing the initial state of the game.
        :param update_interval: Optional. If set, a partial saved game is put in the update_queue every
                                `update_interval` seconds.
        :param update_queue: Optional. If update_interval is set, partial games will be put in this queue
        :param output_format: The output format. One of 'proto', 'bytes', 'zlib'
        :return: A SavedGameProto representing the game played (with policy details and power assignments)
                 Depending on format, the output might be converted to a byte array, or a compressed byte array.
        :type players: List[diplomacy_research.players.player.Player]
        :type reward_fn: diplomacy_research.models.self_play.reward_functions.AbstractRewardFunction
        :type advantage_fn: diplomacy_research.models.self_play.advantages.base_advantage.BaseAdvantage
        :type update_queue: multiprocessing.Queue
    """
    # pylint: disable=too-many-arguments
    assert output_format in ['proto', 'bytes', 'zlib'], 'Format should be "proto", "bytes", "zlib"'
    assert len(players) == NB_POWERS

    # Making sure we use the SaveGame wrapper to record the game
    if env_constructor:
        env = env_constructor(players)
    else:
        env = default_env_constructor(players, hparams, power_assignments, set_player_seed, initial_state_bytes)
    wrapped_env = env
    while not isinstance(wrapped_env, DiplomacyEnv):
        if isinstance(wrapped_env, SaveGame):
            break
        wrapped_env = wrapped_env.env
    else:
        env = SaveGame(env)

    # Detecting if we have an Auto-Draw wrapper
    has_auto_draw = False
    wrapped_env = env
    while not isinstance(wrapped_env, DiplomacyEnv):
        if isinstance(wrapped_env, AutoDraw):
            has_auto_draw = True
            break
        wrapped_env = wrapped_env.env

    # Resetting env
    env.reset()

    # Timing vars for partial updates
    time_last_update = time.time()
    year_last_update = 0
    start_phase_ix = 0
    current_phase_ix = 0
    nb_transitions = 0

    # Cache Variables
    powers = sorted([power_name for power_name in get_map_powers(env.game.map)])
    assigned_powers = env.get_all_powers_name()
    stored_board_state = OrderedDict()              # {phase_name: board_state}
    stored_prev_orders_state = OrderedDict()        # {phase_name: prev_orders_state}
    stored_possible_orders = OrderedDict()          # {phase_name: possible_orders}
    power_variables = {power_name: {'orders': [],
                                    'policy_details': [],
                                    'state_values': [],
                                    'rewards': [],
                                    'returns': [],
                                    'last_state_value': 0.} for power_name in powers}

    new_state_proto = None
    phase_history_proto = []
    map_object = Map(name=env.game.map.name)

    # Generating
    while not env.is_done:
        state_proto = new_state_proto if new_state_proto is not None else extract_state_proto(env.game)
        possible_orders_proto = extract_possible_orders_proto(env.game)

        # Computing board_state
        board_state = proto_to_board_state(state_proto, map_object).flatten().tolist()
        state_proto.board_state.extend(board_state)

        # Storing possible orders for this phase
        current_phase = env.game.get_current_phase()
        stored_board_state[current_phase] = board_state
        stored_possible_orders[current_phase] = possible_orders_proto

        # Getting orders, policy details, and state value
        tasks = [(player, state_proto, pow_name, phase_history_proto[-NB_PREV_ORDERS_HISTORY:],
                  possible_orders_proto) for player, pow_name in zip(env.players, assigned_powers)]
        step_args = yield [get_step_args(*args) for args in tasks]

        # Stepping through env, storing power variables
        for power_name, (orders, policy_details, state_value) in zip(assigned_powers, step_args):
            if orders:
                env.step((power_name, orders))
                nb_transitions += 1
            if has_auto_draw and policy_details is not None:
                env.set_draw_prob(power_name, policy_details['draw_prob'])

        # Processing
        env.process()
        current_phase_ix += 1

        # Retrieving draw action and saving power variables
        for power_name, (orders, policy_details, state_value) in zip(assigned_powers, step_args):
            if has_auto_draw and policy_details is not None:
                policy_details['draw_action'] = env.get_draw_actions()[power_name]
            power_variables[power_name]['orders'] += [orders]
            power_variables[power_name]['policy_details'] += [policy_details]
            power_variables[power_name]['state_values'] += [state_value]

        # Getting new state
        new_state_proto = extract_state_proto(env.game)

        # Storing reward for this transition
        done_reason = DoneReason(env.done_reason) if env.done_reason else None
        for power_name in powers:
            power_variables[power_name]['rewards'] += [reward_fn.get_reward(prev_state_proto=state_proto,
                                                                            state_proto=new_state_proto,
                                                                            power_name=power_name,
                                                                            is_terminal_state=done_reason is not None,
                                                                            done_reason=done_reason)]

        # Computing prev_orders_state for the previous state
        last_phase_proto = extract_phase_history_proto(env.game, nb_previous_phases=1)[-1]
        if last_phase_proto.name[-1] == 'M':
            prev_orders_state = proto_to_prev_orders_state(last_phase_proto, map_object).flatten().tolist()
            stored_prev_orders_state[last_phase_proto.name] = prev_orders_state
            last_phase_proto.prev_orders_state.extend(prev_orders_state)
        phase_history_proto += [last_phase_proto]

        # Sending partial game if:
        # 1) We have update_interval > 0 with an update queue, and
        # 2a) The game is completed, or 2b) the update interval has elapsed and at least 5 years have passed
        has_update_interval = update_interval > 0 and update_queue is not None
        game_is_completed = env.is_done
        min_time_has_passed = time.time() - time_last_update > update_interval
        current_year = 9999 if env.game.get_current_phase() == 'COMPLETED' else int(env.game.get_current_phase()[1:5])
        min_years_have_passed = current_year - year_last_update >= 5

        if (has_update_interval
                and (game_is_completed
                     or (min_time_has_passed and min_years_have_passed))):

            # Game is completed - last state value is 0
            if game_is_completed:
                for power_name in powers:
                    power_variables[power_name]['last_state_value'] = 0.

            # Otherwise - Querying the model for the value of the last state
            else:
                tasks = [(player, new_state_proto, pow_name, phase_history_proto[-NB_PREV_ORDERS_HISTORY:],
                          possible_orders_proto) for player, pow_name in zip(env.players, assigned_powers)]
                last_state_values = yield [get_state_value(*args) for args in tasks]

                for power_name, last_state_value in zip(assigned_powers, last_state_values):
                    power_variables[power_name]['last_state_value'] = last_state_value

            # Getting partial game and sending it on the update_queue
            saved_game_proto = get_saved_game_proto(env=env,
                                                    players=players,
                                                    stored_board_state=stored_board_state,
                                                    stored_prev_orders_state=stored_prev_orders_state,
                                                    stored_possible_orders=stored_possible_orders,
                                                    power_variables=power_variables,
                                                    start_phase_ix=start_phase_ix,
                                                    reward_fn=reward_fn,
                                                    advantage_fn=advantage_fn,
                                                    is_partial_game=True)
            update_queue.put_nowait((False, nb_transitions, proto_to_bytes(saved_game_proto)))

            # Updating stats
            start_phase_ix = current_phase_ix
            nb_transitions = 0
            if not env.is_done:
                year_last_update = int(env.game.get_current_phase()[1:5])

    # Since the environment is done (Completed game) - We can leave the last_state_value at 0.
    for power_name in powers:
        power_variables[power_name]['last_state_value'] = 0.

    # Getting completed game
    saved_game_proto = get_saved_game_proto(env=env,
                                            players=players,
                                            stored_board_state=stored_board_state,
                                            stored_prev_orders_state=stored_prev_orders_state,
                                            stored_possible_orders=stored_possible_orders,
                                            power_variables=power_variables,
                                            start_phase_ix=0,
                                            reward_fn=reward_fn,
                                            advantage_fn=advantage_fn,
                                            is_partial_game=False)

    # Converting to correct format
    output = {'proto': lambda proto: proto,
              'zlib': proto_to_zlib,
              'bytes': proto_to_bytes}[output_format](saved_game_proto)

    # Returning
    return output

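# Illustrative sketch (not part of the original module): generate_trajectory is written as a
# generator/coroutine - it yields lists of pending player queries (get_step_args, get_state_value)
# and expects their resolved results to be sent back in the same order. A minimal synchronous driver,
# assuming a hypothetical `resolve(queries)` callable that evaluates each yielded query, might look like:
def _drive_trajectory(players, reward_fn, advantage_fn, resolve):
    """ Drives the generate_trajectory coroutine to completion and returns its output (sketch only). """
    trajectory = generate_trajectory(players, reward_fn, advantage_fn, output_format='proto')
    try:
        queries = next(trajectory)                          # first batch of yielded player queries
        while True:
            queries = trajectory.send(resolve(queries))     # send resolved results, get the next batch
    except StopIteration as done:
        return done.value                                   # the saved game proto returned by the coroutine
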
def get_policy_data(saved_game_proto, power_names, top_victors):
    """ Computes the proto to save in tf.train.Example as a training example for the policy network
        :param saved_game_proto: A `.proto.game.SavedGame` object from the dataset.
        :param power_names: The list of powers for which we want the policy data
        :param top_victors: The list of powers that ended with more than 25% of the supply centers
        :return: A dictionary with key: the phase_ix
                 with value: A dict with the power_name as key and a dict with the example fields as value
    """
    # pylint: disable=too-many-branches
    nb_phases = len(saved_game_proto.phases)
    policy_data = {phase_ix: {} for phase_ix in range(nb_phases - 1)}
    game_id = saved_game_proto.id
    map_object = Map(saved_game_proto.map)

    # Determining if we have a draw
    nb_sc_to_win = len(map_object.scs) // 2 + 1
    has_solo_winner = max([len(saved_game_proto.phases[-1].state.centers[power_name].value)
                           for power_name in saved_game_proto.phases[-1].state.centers]) >= nb_sc_to_win
    survivors = [power_name for power_name in saved_game_proto.phases[-1].state.centers
                 if saved_game_proto.phases[-1].state.centers[power_name].value]
    has_draw = not has_solo_winner and len(survivors) >= 2

    # Processing all phases (except the last one)
    current_year = 0
    for phase_ix in range(nb_phases - 1):

        # Building a list of orders of previous phases
        previous_orders_states = [np.zeros((NB_NODES, NB_ORDERS_FEATURES), dtype=np.uint8)] * NB_PREV_ORDERS
        for phase_proto in saved_game_proto.phases[max(0, phase_ix - NB_PREV_ORDERS_HISTORY):phase_ix]:
            if phase_proto.name[-1] == 'M':
                previous_orders_states += [proto_to_prev_orders_state(phase_proto, map_object)]
        previous_orders_states = previous_orders_states[-NB_PREV_ORDERS:]
        prev_orders_state = np.array(previous_orders_states)

        # Parsing each requested power in the specified phase
        phase_proto = saved_game_proto.phases[phase_ix]
        phase_name = phase_proto.name
        state_proto = phase_proto.state
        phase_board_state = proto_to_board_state(state_proto, map_object)

        # Increasing year for every spring or when the game is completed
        if phase_proto.name == 'COMPLETED' or (phase_proto.name[0] == 'S' and phase_proto.name[-1] == 'M'):
            current_year += 1

        for power_name in power_names:
            phase_issued_orders = get_issued_orders_for_powers(phase_proto, [power_name])
            phase_possible_orders = get_possible_orders_for_powers(phase_proto, [power_name])
            phase_draw_target = 1. if has_draw and phase_ix == (nb_phases - 2) and power_name in survivors else 0.

            # Data to use when not learning a policy
            blank_policy_data = {'board_state': phase_board_state,
                                 'prev_orders_state': prev_orders_state,
                                 'draw_target': phase_draw_target}

            # Power is not a top victor - We don't want to learn a policy from it
            if power_name not in top_victors:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Finding the orderable locs
            orderable_locations = list(phase_issued_orders[power_name].keys())

            # Skipping power for this phase if we are only issuing Hold
            for order_loc, order in phase_issued_orders[power_name].items():
                order_tokens = get_order_tokens(order)
                if len(order_tokens) >= 2 and order_tokens[1] != 'H':
                    break
            else:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Removing orderable locs where orders are not possible (i.e. NO_CHECK games)
            for order_loc, order in phase_issued_orders[power_name].items():
                if order not in phase_possible_orders[order_loc]:
                    if 'NO_CHECK' not in saved_game_proto.rules:
                        LOGGER.warning('%s not in all possible orders. Phase %s - Game %s.',
                                       order, phase_name, game_id)
                    orderable_locations.remove(order_loc)

            # Determining if we are in an adjustment phase
            in_adjustment_phase = state_proto.name[-1] == 'A'
            nb_builds = state_proto.builds[power_name].count
            nb_homes = len(state_proto.builds[power_name].homes)

            # WxxxA with builds >= 0 - We can build units
            # WxxxA with builds < 0 - We must disband units
            # Other phases - One block of tokens per orderable location
            if in_adjustment_phase and nb_builds >= 0:
                decoder_length = TOKENS_PER_ORDER * min(nb_builds, nb_homes)
            elif in_adjustment_phase and nb_builds < 0:
                decoder_length = TOKENS_PER_ORDER * abs(nb_builds)
            else:
                decoder_length = TOKENS_PER_ORDER * len(orderable_locations)

            # Not all units were disbanded - Skipping this power as we can't learn the orders properly
            if in_adjustment_phase and nb_builds < 0 and len(orderable_locations) < abs(nb_builds):
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Not enough orderable locations for this power, skipping
            if not orderable_locations or not decoder_length:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Encoding decoder inputs - Padding each order to 6 tokens
            # The decoder length should be a multiple of 6, since each order is padded to 6 tokens
            decoder_inputs = [GO_ID]
            for loc in orderable_locations[:]:
                order = phase_issued_orders[power_name][loc]
                try:
                    tokens = [token_to_ix(order_token) for order_token in get_order_tokens(order)] + [EOS_ID]
                    tokens += [PAD_ID] * (TOKENS_PER_ORDER - len(tokens))
                    decoder_inputs += tokens
                except KeyError:
                    LOGGER.warning('[data_generator] Order "%s" is not valid. Skipping location.', order)
                    orderable_locations.remove(loc)

            # Adding WAIVE orders
            if in_adjustment_phase and nb_builds > 0:
                waive_tokens = [token_to_ix('WAIVE'), EOS_ID] + [PAD_ID] * (TOKENS_PER_ORDER - 2)
                decoder_inputs += waive_tokens * (min(nb_builds, nb_homes) - len(orderable_locations))
            decoder_length = min(decoder_length, TOKENS_PER_ORDER * NB_SUPPLY_CENTERS)

            # Getting decoder mask
            coords = set()

            # Adjustment phase, we allow all builds / disbands in all positions
            if in_adjustment_phase:
                build_disband_locs = list(get_possible_orders_for_powers(phase_proto, [power_name]).keys())
                phase_board_alignments = get_board_alignments(build_disband_locs,
                                                              in_adjustment_phase=in_adjustment_phase,
                                                              tokens_per_loc=TOKENS_PER_ORDER,
                                                              decoder_length=decoder_length)

                # Building a list of all orders for all locations
                adj_orders = []
                for loc in build_disband_locs:
                    adj_orders += phase_possible_orders[loc]

                # Not learning builds for BUILD_ANY
                if nb_builds > 0 and 'BUILD_ANY' in state_proto.rules:
                    adj_orders = []

                # No orders found - Skipping
                if not adj_orders:
                    policy_data[phase_ix][power_name] = blank_policy_data
                    continue

                # Building a list of coordinates for the decoder mask matrix
                for loc_ix in range(decoder_length):
                    coords = get_token_based_mask(adj_orders, offset=loc_ix * TOKENS_PER_ORDER, coords=coords)

            # Regular phase, we mask for each location
            else:
                phase_board_alignments = get_board_alignments(orderable_locations,
                                                              in_adjustment_phase=in_adjustment_phase,
                                                              tokens_per_loc=TOKENS_PER_ORDER,
                                                              decoder_length=decoder_length)
                for loc_ix, loc in enumerate(orderable_locations):
                    coords = get_token_based_mask(phase_possible_orders[loc] or [''],
                                                  offset=loc_ix * TOKENS_PER_ORDER,
                                                  coords=coords)

            # Saving results
            # No need to return temperature, current_power, current_season
            policy_data[phase_ix][power_name] = {'board_state': phase_board_state,
                                                 'board_alignments': phase_board_alignments,
                                                 'prev_orders_state': prev_orders_state,
                                                 'decoder_inputs': decoder_inputs,
                                                 'decoder_lengths': decoder_length,
                                                 'decoder_mask_indices': list(sorted(coords)),
                                                 'draw_target': phase_draw_target}

    # Returning
    return policy_data
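
# Illustrative note (not part of the original module): in the token-based encoder above, each issued
# order occupies exactly TOKENS_PER_ORDER positions of decoder_inputs - the ids of its order tokens
# from get_order_tokens(), followed by EOS_ID, then PAD_ID up to the block size. For a hypothetical
# order whose get_order_tokens() returns 3 tokens, with TOKENS_PER_ORDER = 6, the block would be:
#
#   [tok_1, tok_2, tok_3, EOS_ID, PAD_ID, PAD_ID]
#
# so decoder_length stays a multiple of TOKENS_PER_ORDER (one fixed-size block per orderable location),
# which is what the order-based variant earlier in this file avoids by emitting a single id per order.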