def get_generator_config(config: configure_pretraining.PretrainingConfig,
                         bert_config: modeling.BertConfig):
    """Derive the generator network's model config from `bert_config`.

    The result is rebuilt from `bert_config.to_dict()` and then has four
    fields overwritten:

    * `hidden_size`: `bert_config.hidden_size` times
      `config.generator_hidden_size`, rounded to an int.
    * `num_hidden_layers`: `bert_config.num_hidden_layers` times
      `config.generator_layers`, rounded to an int.
    * `intermediate_size`: four times the new `hidden_size`.
    * `num_attention_heads`: one head per 64 hidden units, never fewer
      than one.

    Args:
        config: Pretraining settings; only `generator_hidden_size` and
            `generator_layers` are read here, both used as multipliers.
        bert_config: Base model config that is scaled.

    Returns:
        The `modeling.BertConfig` produced by `from_dict`, with the fields
        above replaced.
    """
    # Rebuild from the dict form; presumably this yields an object
    # independent of the caller's `bert_config` (verify in `from_dict`).
    generator_config = modeling.BertConfig.from_dict(bert_config.to_dict())

    width = int(round(bert_config.hidden_size * config.generator_hidden_size))
    depth = int(round(bert_config.num_hidden_layers * config.generator_layers))
    # NOTE(review): nothing here prevents `depth` from rounding to 0 for
    # small multipliers -- confirm callers never configure that.

    generator_config.hidden_size = width
    generator_config.num_hidden_layers = depth
    generator_config.intermediate_size = 4 * width
    # NOTE(review): `width` is not guaranteed to be a multiple of the head
    # count -- confirm the attention code tolerates that.
    generator_config.num_attention_heads = max(width // 64, 1)
    return generator_config
def get_autoencoder_config(config: configure_pretraining.PretrainingConfig,
                           bert_config: modeling.BertConfig):
    """Derive the autoencoder network's model config from `bert_config`.

    The result is rebuilt from `bert_config.to_dict()` and then has four
    fields overwritten:

    * `hidden_size`: `bert_config.hidden_size` times
      `config.autoencoder_hidden_size`, rounded to an int.
    * `num_hidden_layers`: `bert_config.num_hidden_layers` times
      `config.autoencoder_layers`, rounded to an int.
    * `intermediate_size`: four times the new `hidden_size`.
    * `num_attention_heads`: one head per 64 hidden units, never fewer
      than one.

    Args:
        config: Pretraining settings; only `autoencoder_hidden_size` and
            `autoencoder_layers` are read here, both used as multipliers.
        bert_config: Base model config that is scaled.

    Returns:
        The `modeling.BertConfig` produced by `from_dict`, with the fields
        above replaced.
    """
    # Rebuild from the dict form; presumably this yields an object
    # independent of the caller's `bert_config` (verify in `from_dict`).
    autoencoder_config = modeling.BertConfig.from_dict(bert_config.to_dict())

    width = int(
        round(bert_config.hidden_size * config.autoencoder_hidden_size))
    depth = int(
        round(bert_config.num_hidden_layers * config.autoencoder_layers))
    # NOTE(review): nothing here prevents `depth` from rounding to 0 for
    # small multipliers -- confirm callers never configure that.

    autoencoder_config.hidden_size = width
    autoencoder_config.num_hidden_layers = depth
    autoencoder_config.intermediate_size = 4 * width
    # NOTE(review): `width` is not guaranteed to be a multiple of the head
    # count -- confirm the attention code tolerates that.
    autoencoder_config.num_attention_heads = max(width // 64, 1)
    return autoencoder_config