def build_embedder(self, input_ids, token_type_ids,
                   hidden_dropout_prob,
                   attention_probs_dropout_prob,
                   use_bfloat16,
                   is_training,
                   use_tpu,
                   **kargs):
    """Build the input embedding layer, with optional adversarial hooks."""
    embedding_table_adv = kargs.get('embedding_table_adv', None)
    embedding_seq_adv = kargs.get('embedding_seq_adv', None)
    emb_adv_pos = kargs.get("emb_adv_pos", "emb_adv_post")
    stop_gradient = kargs.get("stop_gradient", False)
    tf.logging.info("==embedding_table_adv: %s==", embedding_table_adv)
    tf.logging.info("==embedding_seq_adv: %s==", embedding_seq_adv)

    if self.config.get('embedding_scope', None):
        embedding_scope = self.config['embedding_scope']
        tf.logging.info("==using model_config.embedding_scope: %s ==",
                        embedding_scope)
    else:
        embedding_scope = self.config.get("scope", "model")
        tf.logging.info("==no embedding_scope given, falling back to "
                        "model_config.scope: %s ==", embedding_scope)

    initializer = get_initializer(self.config)
    dtype = tf.bfloat16 if use_bfloat16 else tf.float32

    with tf.variable_scope(embedding_scope, reuse=tf.AUTO_REUSE):
        embed_name = os.path.join(embedding_scope, 'embed')
        [self.input_embed,
         self.word_embed_table,
         self.emb_dict] = funnel_transformer_modules.input_embedding(
            self.config,
            initializer,
            input_ids,
            is_training,
            seg_id=token_type_ids,
            use_tpu=use_tpu,
            dtype=dtype,
            embedding_table_adv=embedding_table_adv,
            embedding_seq_adv=embedding_seq_adv,
            emb_adv_pos=emb_adv_pos,
            stop_gradient=stop_gradient,
            name=embed_name)

    funnel_transformer_ops.update_ret_dict(self.ret_dict, self.emb_dict, "emb")
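# Hedged usage sketch, not part of the original module: one plausible way to
# exercise the adversarial-embedding kwargs above, e.g. for a FreeLB-style
# perturbation on the embedding sequence. `model`, `delta`, and the dropout
# values are illustrative assumptions; the meaning of "emb_adv_post" (apply
# the perturbation after the embedding lookup) is inferred, not confirmed.
def _example_build_embedder_with_adv(model, input_ids, token_type_ids, delta):
    model.build_embedder(
        input_ids, token_type_ids,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        use_bfloat16=False,
        is_training=True,
        use_tpu=False,
        embedding_seq_adv=delta,     # assumed: perturbation on embedding sequence
        emb_adv_pos="emb_adv_post",  # assumed: perturb after embedding lookup
        stop_gradient=True)          # block gradients through the perturbation
    return model.input_embed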
def tfmxl_layer(net_config, q, k, v, pos_enc, seg_mat, attn_mask, is_training,
                initializer, func_mask=None, attn_bias=None, name="tfmxl"):
    """Single transformer-xl layer."""
    ret_dict = {}

    # relative multi-head attention
    output, attn_dict = funnel_transformer_ops.rel_multihead_attn(
        net_config=net_config,
        q=q, k=k, v=v,
        pos_enc=pos_enc,
        seg_mat=seg_mat,
        attn_mask=attn_mask,
        attn_bias=attn_bias,
        d_model=net_config.d_model,
        n_head=net_config.n_head,
        d_head=net_config.d_head,
        dropout=net_config.dropout,
        dropatt=net_config.dropatt,
        is_training=is_training,
        initializer=initializer,
        func_mask=func_mask,
        rel_attn_type=net_config.rel_attn_type,
        name=name)

    # position-wise feed-forward network
    output, pffn_dict = funnel_transformer_ops.positionwise_ffn(
        inp=output,
        d_model=net_config.d_model,
        d_inner=net_config.d_inner,
        activation_type=net_config.ff_activation,
        dropout=net_config.dropout,
        dropact=net_config.dropact,
        is_training=is_training,
        initializer=initializer,
        name=name)

    funnel_transformer_ops.update_ret_dict(ret_dict, attn_dict, "attn")
    funnel_transformer_ops.update_ret_dict(ret_dict, pffn_dict, "pffn")

    return output, ret_dict
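# Hedged usage sketch, not from the original repo: rough wiring of a single
# transformer-xl layer. The relative-attention structures come from
# init_attn_structures, mirroring how decoder()/encoder() call tfmxl_layer.
# All config fields, shapes, and the scope names are illustrative assumptions;
# a real net_config from this repo would carry more fields.
def _example_tfmxl_layer_usage(is_training=True):
    import types
    net_config = types.SimpleNamespace(
        d_model=768, n_head=12, d_head=64, d_inner=3072,
        dropout=0.1, dropatt=0.1, dropact=0.0,
        ff_activation="gelu", rel_attn_type="factorized")
    h = tf.zeros([2, 128, net_config.d_model])   # [batch, seq_len, d_model]
    seg_id = tf.zeros([2, 128], dtype=tf.int32)  # single-segment input
    initializer = tf.truncated_normal_initializer(stddev=0.02)
    pos_enc, seg_mat, func_mask = funnel_transformer_utils.init_attn_structures(
        net_config, None, h, seg_id, None, is_training,
        "example/attn_structures")
    # self-attention: q, k, v all come from the same hidden states
    output, ret_dict = tfmxl_layer(
        net_config, q=h, k=h, v=h,
        pos_enc=pos_enc, seg_mat=seg_mat, attn_mask=None,
        is_training=is_training, initializer=initializer,
        func_mask=func_mask, name="example/tfmxl_layer")
    return output, ret_dict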
def build_decoder(self, hiddens, input_ids, input_mask, token_type_ids,
                  is_training, **kargs):
    """Build the decoder that restores the pooled sequence to full length."""
    if self.config.n_block > 1:
        initializer = get_initializer(self.config)
        scope = self.config.get("scope", "model")
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            [self.decoder_output,
             self.dec_dict] = funnel_transformer_modules.decoder(
                self.config,
                hiddens,
                input_mask=input_mask,
                seg_id=token_type_ids,
                is_training=is_training,
                initializer=initializer,
                attn_structures=self.attn_structures)
        funnel_transformer_ops.update_ret_dict(self.ret_dict, self.dec_dict, "dec")
    else:
        # single-block model: nothing is pooled, so no decoder is needed
        self.decoder_output = None
        self.dec_dict = {}
def build_encoder(self, input_ids, input_mask, token_type_ids,
                  hidden_dropout_prob,
                  attention_probs_dropout_prob,
                  is_training,
                  use_bfloat16,
                  embedding_output=None,
                  **kargs):
    """Build the Funnel-Transformer encoder on top of the embeddings."""
    initializer = get_initializer(self.config)
    dtype = tf.bfloat16 if use_bfloat16 else tf.float32
    scope = self.config.get("scope", "model")

    self.attn_structures = None

    if embedding_output is not None:
        # use an externally supplied embedding sequence
        embedding_seq_output = embedding_output
        tf.logging.info("****** outer-embedding_seq_output *******")
    else:
        # use the embedding produced by build_embedder
        embedding_seq_output = self.input_embed
        tf.logging.info("****** self-embedding_seq_output *******")

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        [self.encoder_output,
         self.encoder_hiddens,
         self.enc_dict,
         self.attn_structures] = funnel_transformer_modules.encoder(
            self.config,
            embedding_seq_output,
            is_training,
            initializer,
            seg_id=token_type_ids,
            input_mask=input_mask,
            attn_structures=self.attn_structures)
    tf.logging.info("==attention structures: %s==", self.attn_structures)
    funnel_transformer_ops.update_ret_dict(self.ret_dict, self.enc_dict, "enc")
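# Hedged pipeline sketch, not from the original repo: the usual build order
# for a model object is embed, then encode, then decode (build_decoder above).
# `model` is assumed to be an instance of the class these build_* methods
# belong to, with `model.config` and `model.ret_dict` already initialized;
# the dropout values are illustrative.
def _example_build_pipeline(model, input_ids, input_mask, token_type_ids):
    model.build_embedder(
        input_ids, token_type_ids,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        use_bfloat16=False, is_training=True, use_tpu=False)
    model.build_encoder(
        input_ids, input_mask, token_type_ids,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        is_training=True, use_bfloat16=False)
    # build_decoder is a no-op for single-block configs (n_block == 1)
    model.build_decoder(
        model.encoder_hiddens, input_ids, input_mask, token_type_ids,
        is_training=True)
    return model.encoder_output, model.decoder_output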
def decoder(net_config, hiddens, is_training, initializer,
            input_mask=None, seg_id=None, pos_id=None,
            scope="decoder", reuse=tf.AUTO_REUSE,
            attn_structures=None,
            **kargs):
    """Decode a compressed sequence into a full-length sequence."""
    ret_dict = {}

    # bridge layer: up-sample the compressed hidden states to full length
    output, bridge_dict = funnel_transformer_utils.bridge_layer(
        net_config, hiddens, input_mask, reuse=reuse)
    funnel_transformer_ops.update_ret_dict(ret_dict, bridge_dict, "bridge")

    if net_config.decoder_depth == 0:
        return output, ret_dict

    # prepare structures for relative attention
    attn_structures_name = os.path.join(scope, 'attn_structures')
    pos_enc, seg_mat, func_mask = funnel_transformer_utils.init_attn_structures(
        net_config, attn_structures, output, seg_id, pos_id, is_training,
        attn_structures_name)
    attn_mask = None if input_mask is None else input_mask[:, None, None]

    # decoder layers
    n_enc_param_layer = sum(net_config.block_param_size)
    with tf.variable_scope(scope, reuse=reuse):
        for param_idx in range(net_config.decoder_param_size):
            layer_idx = n_enc_param_layer + param_idx
            with tf.variable_scope("layer_{}".format(layer_idx), reuse=reuse):
                for repeat_idx in range(net_config.decoder_repeat_size):
                    tfmxl_name = os.path.join(
                        scope, str(layer_idx), str(repeat_idx), 'tfmxl_layer')
                    output, layer_dict = funnel_transformer_utils.tfmxl_layer(
                        net_config=net_config,
                        q=output,
                        k=output,
                        v=output,
                        pos_enc=pos_enc,
                        seg_mat=seg_mat,
                        attn_mask=attn_mask,
                        is_training=is_training,
                        initializer=initializer,
                        func_mask=func_mask,
                        name=tfmxl_name)
                    funnel_transformer_ops.update_ret_dict(
                        ret_dict, layer_dict,
                        "layer_{}/repeat_{}".format(layer_idx, repeat_idx))

    return output, ret_dict
def encoder(net_config, input_embed, is_training, initializer,
            seg_id=None, pos_id=None, input_mask=None,
            scope="encoder", reuse=tf.AUTO_REUSE,
            seq_type=None, mask_type=None,
            attn_structures=None,
            **kargs):
    """Encoder of the Funnel-Transformer."""
    ret_dict = {}

    with tf.variable_scope(scope, reuse=reuse):
        ##### Input projection
        output, _ = funnel_transformer_utils.input_projection(
            net_config, input_embed, initializer)

        ##### Encoder layers
        hiddens = []
        layer_dict = {}
        for block_idx in range(net_config.n_block):
            # prepare structures for relative attention
            if block_idx == 0:
                attn_structures_name = os.path.join(
                    scope, str(block_idx), 'attn_structures')
                pos_enc, seg_mat, func_mask = \
                    funnel_transformer_utils.init_attn_structures(
                        net_config, attn_structures, output, seg_id, pos_id,
                        is_training, attn_structures_name)
                if attn_structures is None:
                    attn_structures = (pos_enc, seg_mat, func_mask)
            else:
                pre_attn_pooling_name = os.path.join(
                    scope, str(block_idx), 'pre_attn_pooling')
                pool_ret = funnel_transformer_utils.pre_attn_pooling(
                    net_config, output, pos_enc, seg_mat, input_mask,
                    func_mask, block_idx, is_training, pre_attn_pooling_name)
                pooled_out, pos_enc, seg_mat, input_mask, func_mask = pool_ret

            attn_mask = None if input_mask is None else input_mask[:, None, None]

            for param_idx in range(net_config.block_param_size[block_idx]):
                ##### current layer idx
                layer_idx = sum(
                    net_config.block_param_size[:block_idx]) + param_idx
                with tf.variable_scope("layer_{}".format(layer_idx), reuse=reuse):
                    cur_repeat_size = net_config.block_repeat_size[block_idx]
                    for repeat_idx in range(cur_repeat_size):
                        sub_idx = param_idx * cur_repeat_size + repeat_idx
                        # pool only at the first sub-layer of each block
                        # after the first block
                        do_pooling = block_idx > 0 and sub_idx == 0
                        tf.logging.info("===do pooling: %s===", do_pooling)

                        # prepare inputs to the current layer
                        if do_pooling:
                            if net_config.pool_q_only:
                                # pool only the query; keys/values keep
                                # the un-pooled sequence
                                q = pooled_out
                                k = v = output
                            else:
                                q = k = v = pooled_out
                        else:
                            q = k = v = output

                        # attention layer
                        tfmxl_name = os.path.join(
                            scope, str(block_idx), str(layer_idx),
                            str(repeat_idx), 'tfmxl_layer')
                        output, layer_dict = funnel_transformer_utils.tfmxl_layer(
                            net_config=net_config,
                            q=q, k=k, v=v,
                            pos_enc=pos_enc,
                            seg_mat=seg_mat,
                            attn_mask=attn_mask,
                            is_training=is_training,
                            initializer=initializer,
                            func_mask=func_mask,
                            name=tfmxl_name)

                        # post-attention pooling
                        if do_pooling:
                            post_attn_pooling_name = os.path.join(
                                scope, str(block_idx), str(layer_idx),
                                str(repeat_idx), 'post_attn_pooling')
                            pool_ret = funnel_transformer_utils.post_attn_pooling(
                                net_config, pos_enc, seg_mat, input_mask,
                                func_mask, block_idx, is_training,
                                post_attn_pooling_name)
                            pos_enc, seg_mat, input_mask, func_mask = pool_ret
                            attn_mask = None if input_mask is None \
                                else input_mask[:, None, None]

                        # update ret dict
                        hiddens.append(output)
                        prefix = "block_{}/layer_{}/repeat_{}".format(
                            block_idx, layer_idx, repeat_idx)
                        funnel_transformer_ops.update_ret_dict(
                            ret_dict, layer_dict, prefix)

    return output, hiddens, ret_dict, attn_structures
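# Hedged composition sketch, not from the original repo: encoder() pools the
# sequence block by block, and decoder() (above) bridges the compressed hidden
# states back to full length for token-level outputs. The attn_structures
# returned by the encoder are threaded into the decoder so both share the same
# relative-attention setup, matching how build_encoder/build_decoder wire them.
def _example_encoder_decoder_flow(net_config, input_embed, seg_id, input_mask,
                                  initializer, is_training=True):
    output, hiddens, enc_dict, attn_structures = encoder(
        net_config, input_embed, is_training, initializer,
        seg_id=seg_id, input_mask=input_mask)
    full_len_output, dec_dict = decoder(
        net_config, hiddens, is_training, initializer,
        input_mask=input_mask, seg_id=seg_id,
        attn_structures=attn_structures)
    return full_len_output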