def test_construct_grammar_one_term_at_a_time(self):
  """ONE_TERM_AT_A_TIME grammar over a small untyped vocab matches the golden rules."""
  config = grammar_lib.GrammarConfig(
      grammar_type=grammar_lib.GrammarType.ONE_TERM_AT_A_TIME,
      split_vocabulary_by_type=False)
  grammar = grammar_lib.construct_grammar(
      grammar_config=config, vocab={"car", "##ry", "in", "##ner"})
  actual_rules = {str(production) for production in grammar.productions()}
  expected_rules = {
      # Word recursion.
      "_W_ -> _Vw_",
      "_W_ -> _Vw_ _W-_",
      "_W-_ -> _Vsh_",
      "_W-_ -> _Vsh_ _W-_",
      # Terminal rules.
      "_Vw_ -> 'car'",
      "_Vw_ -> 'in'",
      "_Vsh_ -> '##ry'",
      "_Vsh_ -> '##ner'",
      # Q-rules.
      "_Q_ -> '[stop]'",  # No more adjustments.
      "_Q_ -> '[pos]' '[contents]' _W_ _Q_",
      "_Q_ -> '[neg]' '[contents]' _W_ _Q_",
      "_Q_ -> '[pos]' '[title]' _W_ _Q_",
      "_Q_ -> '[neg]' '[title]' _W_ _Q_",
  }
  self.assertEqual(actual_rules, expected_rules)
def test_construct_grammar_add_term_only(self):
  """ADD_TERM_ONLY grammar over a small untyped vocab matches the golden rules."""
  config = grammar_lib.GrammarConfig(
      grammar_type=grammar_lib.GrammarType.ADD_TERM_ONLY,
      split_vocabulary_by_type=False)
  grammar = grammar_lib.construct_grammar(
      grammar_config=config, vocab={"car", "##ry", "in", "##ner"})
  actual_rules = {str(production) for production in grammar.productions()}
  expected_rules = {
      # Word recursion.
      "_W_ -> _Vw_",
      "_W_ -> _Vw_ _W-_",
      "_W-_ -> _Vsh_",
      "_W-_ -> _Vsh_ _W-_",
      # Terminal rules.
      "_Vw_ -> 'car'",
      "_Vw_ -> 'in'",
      "_Vsh_ -> '##ry'",
      "_Vsh_ -> '##ner'",
      # Q-rules.
      "_Q_ -> '[stop]'",  # No more adjustments.
      "_Q_ -> '[or]' _W_ _Q_",
  }
  self.assertEqual(actual_rules, expected_rules)
def get_descriptor() -> mzcore.EnvironmentDescriptor:
  """Builds the environment descriptor for the grammar-guided query agent.

  Reads all configuration from command-line flags (`common_flags`,
  `learner_flags`) and module-level state (`bert_state_lib`,
  `grammar_lib`), so it must be called after flag parsing.

  Returns:
    An `mzcore.EnvironmentDescriptor` bundling the observation/action
    spaces, reward/value ranges, pretraining space, and an `extras` dict
    of model/tokenizer/grammar configuration consumed elsewhere.
  """
  sequence_length = common_flags.BERT_SEQUENCE_LENGTH.value
  bert_config = configs.BertConfig.from_json_file(
      common_flags.BERT_CONFIG.value)
  grammar_config = grammar_lib.grammar_config_from_flags()
  # Largest type vocabulary determines the upper bound of the type-id
  # observation channel below.
  max_len_type_vocab = max(map(len, bert_state_lib.TYPE_VOCABS.values()))
  tokenizer = bert_state_lib.get_tokenizer()
  # The grammar's terminals come from the full tokenizer vocabulary.
  grammar = grammar_lib.construct_grammar(grammar_config=grammar_config,
                                          vocab=list(
                                              tokenizer.get_vocab().keys()))
  # Observation = (token ids, per-token type ids, per-token float features,
  # encoded action history). NOTE(review): Box bounds here appear to be
  # inclusive upper bounds on the ids — confirm against the consumer.
  observation_space = gym.spaces.Tuple([
      gym.spaces.Box(0, bert_config.vocab_size, (sequence_length, ),
                     np.int32),
      gym.spaces.Box(0, max_len_type_vocab,
                     (sequence_length, len(bert_state_lib.TYPE_VOCABS)),
                     np.int32),
      gym.spaces.Box(-np.inf, np.inf,
                     (sequence_length, len(bert_state_lib.FLOAT_NAMES)),
                     np.float32),
      gym.spaces.Box(0, len(grammar.productions()) + 1,
                     (common_flags.N_ACTIONS_ENCODING.value, ),
                     np.int32),  # +1 for mask
  ])
  max_episode_length = common_flags.MAX_NUM_ACTIONS.value
  # If you change rewards / add new rewards, make sure to update the bounds.
  # Per-step score range, keyed by reward name; default covers all others.
  min_possible_score, max_possible_score = {
      'curiosity+dcg': (-1, 1),
  }.get(common_flags.REWARD.value, (-1, 1))
  # Cumulative (value) range; falls back to the per-step range.
  min_possible_cumulative_score, max_possible_cumulative_score = {
      'curiosity+dcg': (-2, 2),
  }.get(common_flags.REWARD.value, (min_possible_score, max_possible_score))
  logging.info('Max episode length: %d; Score range: [%.2f, %.2f]',
               max_episode_length, min_possible_score, max_possible_score)
  # Additional statistics that we want to track 'in' the learner
  learner_stats = (
      ('ndcg_score', tf.float32),
      ('ndcg_score_improvement', tf.float32),
      ('em_at_1', tf.float32),
      ('em_at_1_improvement', tf.float32),
      (f'em_at_{int(common_flags.K.value)}', tf.float32),
      (f'em_at_{int(common_flags.K.value)}_improvement', tf.float32),
      (f'recall_at_{int(common_flags.K.value)}', tf.float32),
      (f'recall_at_{int(common_flags.K.value)}_improvement', tf.float32),
      ('recall_at_1', tf.float32),
      ('recall_at_1_improvement', tf.float32),
      ('documents_explored', tf.float32),
  )
  return mzcore.EnvironmentDescriptor(
      observation_space=observation_space,
      # One discrete action per grammar production.
      action_space=gym.spaces.Discrete(len(grammar.productions())),
      reward_range=mzcore.Range(min_possible_score, max_possible_score),
      value_range=mzcore.Range(min_possible_cumulative_score,
                               max_possible_cumulative_score),
      # Pretraining sample = (observation, action, reward, value, mask)
      # repeated for each unroll step.
      pretraining_space=gym.spaces.Tuple([
          observation_space,
          gym.spaces.Box(0, len(grammar.productions()), (),
                         np.int32),  # action
          gym.spaces.Box(0., 1., (), np.float32),  # reward
          gym.spaces.Box(0., 1., (), np.float32),  # value
          gym.spaces.Box(0., 1., (), np.float32),  # mask
      ] * common_flags.PRETRAINING_NUM_UNROLL_STEPS.value),
      extras={
          'bert_config': bert_config,
          'sequence_length': sequence_length,
          'float_names': bert_state_lib.FLOAT_NAMES,
          'type_vocabs': bert_state_lib.TYPE_VOCABS,
          'num_float_features': len(bert_state_lib.FLOAT_NAMES),
          'type_vocab_sizes':
              [len(v) for v in bert_state_lib.TYPE_VOCABS.values()],
          'grammar': grammar,
          'max_episode_length':
              max_episode_length +
              5,  # we never want the agent to finish the episode
          'learner_stats': learner_stats,
          # Only set this if `learner` does not specify an already pre-trained
          # checkpoint.
          'bert_init_ckpt':
              common_flags.BERT_INIT_CKPT.value
              if learner_flags.INIT_CHECKPOINT.value is None else None,
          'action_encoder_hidden_size':
              common_flags.ACTION_ENCODER_HIDDEN_SIZE.value,
          'tokenizer': tokenizer,
          'grammar_config': grammar_config,
          'pretraining_num_unroll_steps':
              common_flags.PRETRAINING_NUM_UNROLL_STEPS.value,
      })