def policy_fn(state):
    """Builds a policy for `state` using its sequence indices as weights.

    The information state string of `state` is looked up in
    `root.info_state_to_sequence_idx` to obtain the first sequence index.
    One consecutive integer per legal action, starting at that index, is
    then used as an unnormalized weight.

    Args:
        state: Object exposing `information_state_string()` and
            `legal_actions()`.

    Returns:
        The result of `rcfr.normalized_by_sum` applied to the list of
        consecutive sequence indices.
    """
    first_idx = root.info_state_to_sequence_idx[
        state.information_state_string()]
    # One weight per legal action: first_idx, first_idx + 1, ...
    end_idx = first_idx + len(state.legal_actions())
    weights = list(range(first_idx, end_idx))
    return rcfr.normalized_by_sum(weights)
def test_normalized_by_sum(self):
    """Checks `rcfr.normalized_by_sum` on a small list of positive weights."""
    weights = [1., 2., 3., 4.]
    # The weights sum to 10, so each expected entry is weight / 10.
    expected = [0.1, 0.2, 0.3, 0.4]
    actual = rcfr.normalized_by_sum(weights)
    self.assertAllClose(actual, expected)