コード例 #1
0
            )
        super(EmbeddingSim, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        """Return the feature shape with its last axis replaced by the token count.

        # Arguments
            input_shape: A pair ``(feature_shape, embed_shape)``.
                ``embed_shape[0]`` is used as the number of tokens.

        # Returns
            ``feature_shape[:-1] + (embed_shape[0],)``.
        """
        feature_shape, embed_shape = input_shape
        token_num = embed_shape[0]
        leading_shape = feature_shape[:-1]
        # A caller may hand shapes over as lists; ``list + tuple`` raises
        # TypeError, so normalize the slice to a tuple first. Tuples and
        # shape objects that support ``+ tuple`` are left untouched.
        if isinstance(leading_shape, list):
            leading_shape = tuple(leading_shape)
        return leading_shape + (token_num,)

    def compute_mask(self, inputs, mask=None):
        """Propagate the mask of the first (feature) input.

        # Arguments
            inputs: The layer inputs (not inspected).
            mask: ``None`` or an indexable collection of per-input masks.

        # Returns
            ``None`` when no mask was supplied, otherwise ``mask[0]``.
        """
        return None if mask is None else mask[0]

    def call(self, inputs, mask=None, **kwargs):
        """Score the features against every row of the embedding matrix.

        # Arguments
            inputs: A pair ``(features, embeddings)``. The features are
                multiplied by the transposed embeddings, so presumably their
                last axis matches the embeddings' second axis (TODO confirm).
            mask: Accepted for the Keras interface; not used here.

        # Returns
            The raw scores when ``self.return_logits`` is set, otherwise the
            scores passed through ``keras.activations.softmax``.
        """
        features, embedding_matrix = inputs
        if self.stop_gradient:
            # Block gradients from flowing back into the embedding tensor.
            embedding_matrix = K.stop_gradient(embedding_matrix)
        logits = K.dot(features, K.transpose(embedding_matrix))
        if self.use_bias:
            logits = K.bias_add(logits, self.bias)
        if not self.return_logits:
            return keras.activations.softmax(logits)
        return logits

# Add both layers to the global custom-object registry under their class
# names (presumably so saved models using them can be reloaded by name).
utils.get_custom_objects().update(dict(
    EmbeddingRet=EmbeddingRet,
    EmbeddingSim=EmbeddingSim,
))
コード例 #2
0
from .multi_head import MultiHead
from .multi_head_attention import MultiHeadAttention
from transformer_contrib.backend import utils

# Register the multi-head layers in the global custom-object registry,
# keyed by class name.
utils.get_custom_objects().update(dict(
    MultiHead=MultiHead,
    MultiHeadAttention=MultiHeadAttention,
))
コード例 #3
0
    """Restore mask from the second tensor.

    # Input shape
        Tensor with shape: `(batch_size, ...)`.
        Tensor with mask and shape: `(batch_size, ...)`.

    # Output shape
        Tensor with shape: `(batch_size, ...)`.
    """

    def __init__(self, **base_kwargs):
        """Build the layer; every keyword argument goes to the base layer."""
        super(RestoreMask, self).__init__(**base_kwargs)
        # Declared after the base initializer so it is not overwritten there.
        self.supports_masking = True

    def compute_output_shape(self, input_shape):
        """Return the shape of the first input; the layer does not reshape it."""
        data_shape = input_shape[0]
        return data_shape

    def compute_mask(self, inputs, mask=None):
        """Return the mask attached to the second input.

        # Arguments
            inputs: The layer inputs (not inspected).
            mask: ``None`` or an indexable collection of per-input masks.

        # Returns
            ``mask[1]``, or ``None`` when no masks were supplied at all.
        """
        # Keras may call this with ``mask=None`` when no input carries a mask.
        # Indexing ``None`` would raise TypeError, so pass the absence through.
        if mask is None:
            return None
        return mask[1]

    def call(self, inputs, **kwargs):
        """Return the first input unchanged, wrapped in ``K.identity``."""
        data = inputs[0]
        return K.identity(data)

# Register the mask helper layers in the global custom-object registry,
# keyed by class name.
utils.get_custom_objects().update(dict(
    CreateMask=CreateMask,
    RemoveMask=RemoveMask,
    RestoreMask=RestoreMask,
))
コード例 #4
0
from .pos_embd import PositionEmbedding
from .trig_pos_embd import TrigPosEmbedding
from transformer_contrib.backend import utils

# Register the positional-embedding layers in the global custom-object
# registry, keyed by class name.
utils.get_custom_objects().update(dict(
    PositionEmbedding=PositionEmbedding,
    TrigPosEmbedding=TrigPosEmbedding,
))
コード例 #5
0
from .feed_forward import FeedForward
from transformer_contrib.backend import utils

# Register the feed-forward layer in the global custom-object registry.
utils.get_custom_objects().update(dict(FeedForward=FeedForward))
コード例 #6
0
            self.gamma = self.add_weight(
                shape=shape,
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
                name='gamma',
            )
        if self.center:
            self.beta = self.add_weight(
                shape=shape,
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
                name='beta',
            )
        super(LayerNormalization, self).build(input_shape)

    def call(self, inputs, training=None):
        """Normalize ``inputs`` over the last axis, then optionally rescale.

        Each slice along the last axis is shifted to zero mean and divided by
        ``sqrt(variance + self.epsilon)``. It is then multiplied by
        ``self.gamma`` when ``self.scale`` is set and offset by ``self.beta``
        when ``self.center`` is set.

        # Arguments
            inputs: Tensor to normalize.
            training: Accepted for the Keras interface; not used.

        # Returns
            The normalized tensor.
        """
        centered = inputs - K.mean(inputs, axis=-1, keepdims=True)
        variance = K.mean(K.square(centered), axis=-1, keepdims=True)
        # epsilon keeps the square root (and the division) away from zero.
        normalized = centered / K.sqrt(variance + self.epsilon)
        if self.scale:
            normalized = normalized * self.gamma
        if self.center:
            normalized = normalized + self.beta
        return normalized


# Register the layer in the global custom-object registry by class name.
utils.get_custom_objects().update(
    dict(LayerNormalization=LayerNormalization),
)
コード例 #7
0
from .embedding import *
from .softmax import *
from transformer_contrib.backend import utils

# Register the adaptive embedding/softmax layers in the global custom-object
# registry, keyed by class name.
utils.get_custom_objects().update(dict(
    AdaptiveEmbedding=AdaptiveEmbedding,
    AdaptiveSoftmax=AdaptiveSoftmax,
))
コード例 #8
0
from .seq_self_attention import SeqSelfAttention
from .seq_weighted_attention import SeqWeightedAttention
from .scaled_dot_attention import ScaledDotProductAttention
from transformer_contrib.backend import utils

# Register the attention layers in the global custom-object registry,
# keyed by class name.
utils.get_custom_objects().update(dict(
    SeqSelfAttention=SeqSelfAttention,
    SeqWeightedAttention=SeqWeightedAttention,
    ScaledDotProductAttention=ScaledDotProductAttention,
))
コード例 #9
0
from .gelu import gelu
from .transformer import *
from transformer_contrib.backend import utils

# Make the activation function available in the global custom-object
# registry under the name 'gelu'.
utils.get_custom_objects().update(dict(gelu=gelu))