Exemplo n.º 1
0
def get_descriptor():
    """Returns the environment descriptor for the square-board environment.

    The board side length is read from ``FLAGS.env_board_size``.
    - When ``FLAGS.env_compact`` is set, the observation is a single-channel
      board with values in [-1, 1].
    - Otherwise it is a three-channel board with values in [0, 1].

    The action space has ``env_board_size ** 2`` discrete actions, presumably
    one per board cell.

    Returns:
      An ``mzcore.EnvironmentDescriptor`` whose reward and value ranges are
      both [-2, 1].
    """
    side = FLAGS.env_board_size
    if FLAGS.env_compact:
        low, num_channels = -1, 1
    else:
        low, num_channels = 0, 3
    observation_space = gym.spaces.Box(low, 1, (side, side, num_channels),
                                       np.float32)

    return mzcore.EnvironmentDescriptor(
        observation_space=observation_space,
        action_space=gym.spaces.Discrete(side**2),
        reward_range=mzcore.Range(-2., 1.),
        value_range=mzcore.Range(-2., 1.),
    )
Exemplo n.º 2
0
def get_descriptor() -> mzcore.EnvironmentDescriptor:
    """Builds the environment descriptor for the BERT/grammar-based agent.

    The function:
    - reads its configuration from ``common_flags``, and from
      ``learner_flags`` for the initial checkpoint;
    - loads the BERT config and tokenizer;
    - builds the action grammar over the tokenizer vocabulary;
    - packages the spaces, ranges and model-construction extras.

    Returns:
      An ``mzcore.EnvironmentDescriptor`` with these parts:
      - Observation: a 4-tuple of boxes.
        - int32 ids in [0, vocab_size] of shape (sequence_length,).
        - int32 ids per type vocab, shape (sequence_length, num_type_vocabs).
        - Unbounded float32 features, shape (sequence_length,
          num_float_names).
        - An int32 action encoding of length ``N_ACTIONS_ENCODING``.
      - Action space: one discrete action per grammar production.
      - Pretraining space: (observation, action, reward, value, mask) repeated
        ``PRETRAINING_NUM_UNROLL_STEPS`` times.
    """
    sequence_length = common_flags.BERT_SEQUENCE_LENGTH.value
    bert_config = configs.BertConfig.from_json_file(
        common_flags.BERT_CONFIG.value)

    grammar_config = grammar_lib.grammar_config_from_flags()
    tokenizer = bert_state_lib.get_tokenizer()
    vocab_tokens = list(tokenizer.get_vocab().keys())
    grammar = grammar_lib.construct_grammar(grammar_config=grammar_config,
                                            vocab=vocab_tokens)
    num_productions = len(grammar.productions())

    type_vocabs = bert_state_lib.TYPE_VOCABS
    float_names = bert_state_lib.FLOAT_NAMES
    max_len_type_vocab = max(len(vocab) for vocab in type_vocabs.values())

    # --- Observation: one Box per input stream fed to the model. ---
    token_space = gym.spaces.Box(0, bert_config.vocab_size,
                                 (sequence_length, ), np.int32)
    type_space = gym.spaces.Box(0, max_len_type_vocab,
                                (sequence_length, len(type_vocabs)),
                                np.int32)
    float_space = gym.spaces.Box(-np.inf, np.inf,
                                 (sequence_length, len(float_names)),
                                 np.float32)
    # One id past the production count is reserved for the mask.
    action_encoding_space = gym.spaces.Box(
        0, num_productions + 1, (common_flags.N_ACTIONS_ENCODING.value, ),
        np.int32)
    observation_space = gym.spaces.Tuple(
        [token_space, type_space, float_space, action_encoding_space])

    # --- Reward / value bounds. ---
    # If you change rewards / add new rewards, make sure to update the bounds.
    reward_name = common_flags.REWARD.value
    per_step_bounds = {'curiosity+dcg': (-1, 1)}
    cumulative_bounds = {'curiosity+dcg': (-2, 2)}
    min_possible_score, max_possible_score = per_step_bounds.get(
        reward_name, (-1, 1))
    # Unknown rewards fall back to the per-step bounds for the value range.
    min_possible_cumulative_score, max_possible_cumulative_score = (
        cumulative_bounds.get(reward_name,
                              (min_possible_score, max_possible_score)))

    max_episode_length = common_flags.MAX_NUM_ACTIONS.value
    logging.info('Max episode length: %d; Score range: [%.2f, %.2f]',
                 max_episode_length, min_possible_score, max_possible_score)

    # --- Extra statistics tracked 'in' the learner. ---
    # Each base metric is tracked together with its '_improvement' companion.
    k = int(common_flags.K.value)
    base_stats = ('ndcg_score', 'em_at_1', f'em_at_{k}', f'recall_at_{k}',
                  'recall_at_1')
    learner_stats = tuple(
        (stat_name, tf.float32)
        for base in base_stats
        for stat_name in (base, f'{base}_improvement')
    ) + (('documents_explored', tf.float32), )

    # --- Pretraining layout: per-step fields repeated for each unroll step. ---
    num_unroll_steps = common_flags.PRETRAINING_NUM_UNROLL_STEPS.value
    pretraining_step = [
        observation_space,
        gym.spaces.Box(0, num_productions, (), np.int32),  # action
        gym.spaces.Box(0., 1., (), np.float32),  # reward
        gym.spaces.Box(0., 1., (), np.float32),  # value
        gym.spaces.Box(0., 1., (), np.float32),  # mask
    ]
    pretraining_space = gym.spaces.Tuple(pretraining_step * num_unroll_steps)

    # Only set this if `learner` does not specify an already pre-trained
    # checkpoint.
    if learner_flags.INIT_CHECKPOINT.value is None:
        bert_init_ckpt = common_flags.BERT_INIT_CKPT.value
    else:
        bert_init_ckpt = None

    extras = {
        'bert_config': bert_config,
        'sequence_length': sequence_length,
        'float_names': float_names,
        'type_vocabs': type_vocabs,
        'num_float_features': len(float_names),
        'type_vocab_sizes': [len(vocab) for vocab in type_vocabs.values()],
        'grammar': grammar,
        # Padded so the agent never reaches the episode limit itself.
        'max_episode_length': max_episode_length + 5,
        'learner_stats': learner_stats,
        'bert_init_ckpt': bert_init_ckpt,
        'action_encoder_hidden_size':
            common_flags.ACTION_ENCODER_HIDDEN_SIZE.value,
        'tokenizer': tokenizer,
        'grammar_config': grammar_config,
        'pretraining_num_unroll_steps': num_unroll_steps,
    }

    return mzcore.EnvironmentDescriptor(
        observation_space=observation_space,
        action_space=gym.spaces.Discrete(num_productions),
        reward_range=mzcore.Range(min_possible_score, max_possible_score),
        value_range=mzcore.Range(min_possible_cumulative_score,
                                 max_possible_cumulative_score),
        pretraining_space=pretraining_space,
        extras=extras,
    )