Example #1
	def build_embedder(self, input_ids, token_type_ids, 
									hidden_dropout_prob, 
									attention_probs_dropout_prob,
									use_bfloat16,
									is_training,
									use_tpu,
									**kargs):

		embedding_table_adv = kargs.get('embedding_table_adv', None)
		print(embedding_table_adv, "==embedding-adv")

		embedding_seq_adv = kargs.get('embedding_seq_adv', None)
		print(embedding_seq_adv, "==embedding-seq-adv")

		emb_adv_pos = kargs.get("emb_adv_pos", "emb_adv_post")
		stop_gradient = kargs.get("stop_gradient", False)

		if self.config.get('embedding_scope', None):
			embedding_scope = self.config['embedding_scope']
			tf.logging.info("==using embedding scope of original model_config.embedding_scope: %s ==", embedding_scope)
		else:
			embedding_scope = self.config.get("scope", "model")
			tf.logging.info("==falling back to model_config.scope as embedding scope: %s ==", embedding_scope)

		initializer = get_initializer(self.config)

		dtype = tf.float32 if not use_bfloat16 else tf.bfloat16
		with tf.variable_scope(embedding_scope, reuse=tf.AUTO_REUSE):
			embed_name = os.path.join(embedding_scope, 'embed')
			[self.input_embed, 
			self.word_embed_table, 
			self.emb_dict] = funnel_transformer_modules.input_embedding(
				self.config,
				initializer,
				input_ids, 
				is_training, 
				seg_id=token_type_ids, 
				use_tpu=use_tpu, 
				dtype=dtype,
				embedding_table_adv=embedding_table_adv,
				embedding_seq_adv=embedding_seq_adv,
				emb_adv_pos=emb_adv_pos,
				stop_gradient=stop_gradient,
				name=embed_name)
			funnel_transformer_ops.update_ret_dict(self.ret_dict, self.emb_dict, "emb")
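
For context, a minimal sketch of how build_embedder might be invoked. The model instance, the input tensors, and the dropout values below are assumptions; only the adversarial keyword arguments mirror the kargs.get(...) lookups above. After the call, the embeddings are held in self.input_embed and self.word_embed_table.

# Hypothetical call; `model` and the input tensors are assumed to exist.
model.build_embedder(
	input_ids=input_ids,
	token_type_ids=token_type_ids,
	hidden_dropout_prob=0.1,
	attention_probs_dropout_prob=0.1,
	use_bfloat16=False,
	is_training=True,
	use_tpu=False,
	embedding_seq_adv=None,        # optional adversarial perturbation of the embedding sequence
	emb_adv_pos="emb_adv_post",    # where the perturbation is injected
	stop_gradient=False)
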
def tfmxl_layer(net_config, q, k, v, pos_enc, seg_mat, attn_mask, 
								is_training,
								initializer,
								func_mask=None, attn_bias=None,
								name="tfmxl"):

	"""Single transformer-xl layer."""

	ret_dict = {}
	output, attn_dict = funnel_transformer_ops.rel_multihead_attn(
			net_config=net_config,
			q=q,
			k=k,
			v=v,
			pos_enc=pos_enc,
			seg_mat=seg_mat,
			attn_mask=attn_mask,
			attn_bias=attn_bias,
			d_model=net_config.d_model,
			n_head=net_config.n_head,
			d_head=net_config.d_head,
			dropout=net_config.dropout,
			dropatt=net_config.dropatt,
			is_training=is_training,
			initializer=initializer,
			func_mask=func_mask,
			rel_attn_type=net_config.rel_attn_type,
			name=name)

	output, pffn_dict = funnel_transformer_ops.positionwise_ffn(
			inp=output,
			d_model=net_config.d_model,
			d_inner=net_config.d_inner,
			activation_type=net_config.ff_activation,
			dropout=net_config.dropout,
			dropact=net_config.dropact,
			is_training=is_training,
			initializer=initializer,
			name=name)

	funnel_transformer_ops.update_ret_dict(ret_dict, attn_dict, "attn")
	funnel_transformer_ops.update_ret_dict(ret_dict, pffn_dict, "pffn")
	return output, ret_dict
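
As a reference for callers, a sketch of the attributes tfmxl_layer expects on net_config. The container class and the example values are assumptions; the attribute names are taken directly from the calls above.

class NetConfig(object):
	"""Hypothetical config container listing the attributes read by tfmxl_layer."""
	def __init__(self):
		self.d_model = 768                 # model width passed to rel_multihead_attn
		self.n_head = 12                   # number of attention heads
		self.d_head = 64                   # per-head dimension
		self.d_inner = 3072                # inner width of positionwise_ffn
		self.dropout = 0.1                 # hidden/residual dropout
		self.dropatt = 0.1                 # attention-probability dropout
		self.dropact = 0.0                 # FFN activation dropout
		self.ff_activation = "gelu"        # activation_type for positionwise_ffn
		self.rel_attn_type = "factorized"  # relative-attention variant (value assumed)
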
Example #3
	def build_decoder(self, hiddens, input_ids, input_mask, 
									token_type_ids, is_training, **kargs):
		# decoder
		if self.config.n_block > 1:
			initializer = get_initializer(self.config)
			scope = self.config.get("scope", "model")
			with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
				[self.decoder_output, 
				self.dec_dict] = funnel_transformer_modules.decoder(
						self.config,
						hiddens,
						input_mask=input_mask,
						seg_id=token_type_ids,
						is_training=is_training,
						initializer=initializer,
						attn_structures=self.attn_structures)
			funnel_transformer_ops.update_ret_dict(self.ret_dict, 
																						self.dec_dict, 
																						"dec")
		else:
			self.decoder_output = None
			self.dec_dict = {}
Example #4
	def build_encoder(self, input_ids, input_mask, 
									token_type_ids,
									hidden_dropout_prob, 
									attention_probs_dropout_prob,
									is_training,
									use_bfloat16,
									embedding_output=None,
									**kargs):

		initializer = get_initializer(self.config)
		dtype = tf.float32 if not use_bfloat16 else tf.bfloat16
		scope = self.config.get("scope", "model")
		self.attn_structures = None

		if embedding_output is not None:
			embedding_seq_output = embedding_output
			tf.logging.info("****** outer-embedding_seq_output *******")
		else:
			embedding_seq_output = self.input_embed
			tf.logging.info("****** self-embedding_seq_output *******")

		with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
			[self.encoder_output, 
				self.encoder_hiddens, 
				self.enc_dict,
				self.attn_structures] = funnel_transformer_modules.encoder(
					self.config,
					embedding_seq_output,
					is_training,
					initializer,
					seg_id=token_type_ids,
					input_mask=input_mask,
					attn_structures=self.attn_structures)
			print(self.attn_structures, "==attention structures==")

		funnel_transformer_ops.update_ret_dict(self.ret_dict, 
																					self.enc_dict, 
																					"enc")
def decoder(net_config,
            hiddens,
            is_training,
            initializer,
            input_mask=None,
            seg_id=None,
            pos_id=None,
            scope="decoder",
            reuse=tf.AUTO_REUSE,
            attn_structures=None,
            **kargs):
    """Decode a compressed sequence into a full sequence."""
    ret_dict = {}

    output, bridge_dict = funnel_transformer_utils.bridge_layer(net_config,
                                                                hiddens,
                                                                input_mask,
                                                                reuse=reuse)
    funnel_transformer_ops.update_ret_dict(ret_dict, bridge_dict, "bridge")

    if net_config.decoder_depth == 0:
        return output, ret_dict

    # prepare structures for relative attention
    attn_structures_name = os.path.join(scope, 'attn_structures')
    # pos_enc, seg_mat, func_mask = funnel_transformer_utils.init_attn_structures(
    # 		output, seg_id, pos_id, is_training)

    pos_enc, seg_mat, func_mask = funnel_transformer_utils.init_attn_structures(
        net_config, attn_structures, output, seg_id, pos_id, is_training,
        attn_structures_name)
    attn_mask = None if input_mask is None else input_mask[:, None, None]

    # Decoder layers
    n_enc_param_layer = sum(net_config.block_param_size)
    with tf.variable_scope(scope, reuse=reuse):
        for param_idx in range(net_config.decoder_param_size):
            layer_idx = n_enc_param_layer + param_idx
            with tf.variable_scope("layer_{}".format(layer_idx), reuse=reuse):
                for repeat_idx in range(net_config.decoder_repeat_size):
                    tfmxl_name = os.path.join(scope, str(layer_idx),
                                              str(repeat_idx), 'tfmxl_layer')
                    output, layer_dict = funnel_transformer_utils.tfmxl_layer(
                        net_config=net_config,
                        q=output,
                        k=output,
                        v=output,
                        pos_enc=pos_enc,
                        seg_mat=seg_mat,
                        attn_mask=attn_mask,
                        is_training=is_training,
                        initializer=initializer,
                        func_mask=func_mask,
                        name=tfmxl_name)

                    funnel_transformer_ops.update_ret_dict(
                        ret_dict, layer_dict,
                        "layer_{}/repeat_{}".format(layer_idx, repeat_idx))

    return output, ret_dict
def encoder(net_config,
            input_embed,
            is_training,
            initializer,
            seg_id=None,
            pos_id=None,
            input_mask=None,
            scope="encoder",
            reuse=tf.AUTO_REUSE,
            seq_type=None,
            mask_type=None,
            attn_structures=None,
            **kargs):
    """Encoder of the Funnel-Transformer."""
    ret_dict = {}

    with tf.variable_scope(scope, reuse=reuse):

        ##### Input projection
        output, _ = funnel_transformer_utils.input_projection(
            net_config, input_embed, initializer)

        ##### Encoder layers
        hiddens = []
        layer_dict = {}
        for block_idx in range(net_config.n_block):
            # prepare structures for relative attention
            if block_idx == 0:
                attn_structures_name = os.path.join(scope, str(block_idx),
                                                    'attn_structures')
                pos_enc, seg_mat, func_mask = funnel_transformer_utils.init_attn_structures(
                    net_config, attn_structures, output, seg_id, pos_id,
                    is_training, attn_structures_name)
                if attn_structures is None:
                    attn_structures = (pos_enc, seg_mat, func_mask)
            else:
                pre_attn_pooling_name = os.path.join(scope, str(block_idx),
                                                     'pre_attn_pooling')
                pool_ret = funnel_transformer_utils.pre_attn_pooling(
                    net_config, output, pos_enc, seg_mat, input_mask,
                    func_mask, block_idx, is_training, pre_attn_pooling_name)

                pooled_out, pos_enc, seg_mat, input_mask, func_mask = pool_ret
            attn_mask = None if input_mask is None else input_mask[:, None, None]

            for param_idx in range(net_config.block_param_size[block_idx]):
                ##### current layer idx
                layer_idx = sum(
                    net_config.block_param_size[:block_idx]) + param_idx
                with tf.variable_scope("layer_{}".format(layer_idx),
                                       reuse=reuse):
                    cur_repeat_size = net_config.block_repeat_size[block_idx]
                    for repeat_idx in range(cur_repeat_size):
                        sub_idx = (param_idx * cur_repeat_size + repeat_idx)
                        do_pooling = block_idx > 0 and sub_idx == 0
                        print(do_pooling, "===do pooling===")
                        # prepare inputs to the current layer
                        if do_pooling:
                            if net_config.pool_q_only:
                                q = pooled_out
                                k = v = output
                            else:
                                q = k = v = pooled_out
                        else:
                            q = k = v = output

                        # attention layer
                        tfmxl_name = os.path.join(scope, str(block_idx),
                                                  str(layer_idx),
                                                  str(repeat_idx),
                                                  'tfmxl_layer')
                        output, layer_dict = funnel_transformer_utils.tfmxl_layer(
                            net_config=net_config,
                            q=q,
                            k=k,
                            v=v,
                            pos_enc=pos_enc,
                            seg_mat=seg_mat,
                            attn_mask=attn_mask,
                            is_training=is_training,
                            initializer=initializer,
                            func_mask=func_mask,
                            name=tfmxl_name)

                        # post-attention pooling
                        if do_pooling:
                            post_attn_pooling_name = os.path.join(
                                scope, str(block_idx), str(layer_idx),
                                str(repeat_idx), 'post_attn_pooling')
                            pool_ret = funnel_transformer_utils.post_attn_pooling(
                                net_config, pos_enc, seg_mat, input_mask,
                                func_mask, block_idx, is_training,
                                post_attn_pooling_name)
                            pos_enc, seg_mat, input_mask, func_mask = pool_ret
                            attn_mask = None if input_mask is None \
                              else input_mask[:, None, None]

                        # update ret dict
                        hiddens.append(output)
                        prefix = "block_{}/layer_{}/repeat_{}".format(
                            block_idx, layer_idx, repeat_idx)
                        funnel_transformer_ops.update_ret_dict(
                            ret_dict, layer_dict, prefix)

    return output, hiddens, ret_dict, attn_structures
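
Putting the two module-level functions together, a hedged sketch of the encoder-to-decoder wiring that mirrors build_encoder and build_decoder above; net_config, the embedding output, token_type_ids, and input_mask are assumed to already exist.

# Hypothetical end-to-end wiring of the encoder and decoder defined above.
initializer = get_initializer(net_config)
output, hiddens, enc_dict, attn_structures = encoder(
    net_config,
    input_embed,                 # e.g. the embedding produced by build_embedder
    is_training=True,
    initializer=initializer,
    seg_id=token_type_ids,
    input_mask=input_mask)

if net_config.n_block > 1:
    # The decoder up-samples the pooled hidden states back to full length.
    decoder_output, dec_dict = decoder(
        net_config,
        hiddens,
        is_training=True,
        initializer=initializer,
        input_mask=input_mask,
        seg_id=token_type_ids,
        attn_structures=attn_structures)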