import math

import tensorflow as tf
from tensorflow.keras import layers
# StringLookup lives under layers.experimental.preprocessing in older TF 2.x releases.
from tensorflow.keras.layers import StringLookup


def encode_inputs(inputs):
    encoded_features = []
    for feature_name in inputs:
        if feature_name in CATEGORICAL_FEATURE_NAMES:
            vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
            # Create a lookup to convert string values to integer indices.
            # Since we are not using a mask token, nor expecting any out of vocabulary
            # (oov) token, we set mask_token to None and num_oov_indices to 0.
            lookup = StringLookup(vocabulary=vocabulary,
                                  mask_token=None,
                                  num_oov_indices=0)
            # Convert the string input values into integer indices.
            value_index = lookup(inputs[feature_name])
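            # Heuristic: use the square root of the vocabulary size as the embedding size.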
            embedding_dims = int(math.sqrt(lookup.vocabulary_size()))
            # Create an embedding layer with the specified dimensions.
            embedding = layers.Embedding(input_dim=lookup.vocabulary_size(),
                                         output_dim=embedding_dims)
            # Convert the index values to embedding representations.
            encoded_feature = embedding(value_index)
        else:
            # Use the numerical features as-is.
            encoded_feature = inputs[feature_name]
            if inputs[feature_name].shape[-1] is None:
                encoded_feature = tf.expand_dims(encoded_feature, -1)

        encoded_features.append(encoded_feature)

    encoded_features = layers.concatenate(encoded_features)
    return encoded_features
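
# Minimal usage sketch (not part of the original snippet): the feature constants
# and Input layers below are hypothetical and only illustrate how encode_inputs()
# is typically wired into a Keras functional model.
CATEGORICAL_FEATURE_NAMES = ["color"]
CATEGORICAL_FEATURES_WITH_VOCABULARY = {"color": ["red", "green", "blue"]}

example_inputs = {
    # Scalar (shape=()) inputs so numeric features hit the expand_dims branch.
    "color": layers.Input(shape=(), name="color", dtype="string"),
    "age": layers.Input(shape=(), name="age", dtype="float32"),
}
example_features = encode_inputs(example_inputs)
example_outputs = layers.Dense(1, activation="sigmoid")(example_features)
example_model = tf.keras.Model(inputs=example_inputs, outputs=example_outputs)
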
Example 2
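# Note on dependencies (assumption): this example relies on names not defined in
# the snippet itself -- tf (tensorflow), np (numpy), layers/initializers (Keras),
# tf_utils (Keras internals), plus the project-specific helpers Vocabulary,
# AdaptiveEmbedding and miss_text (these look like the nlpvocab/tfmiss packages,
# but the exact import paths are not shown here).
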
class WordEmbedding(layers.Layer):
    UNK_MARK = '[UNK]'
    REP_CHAR = '\uFFFD'

    def __init__(self, vocabulary, output_dim, normalize_unicode='NFKC', lower_case=False, zero_digits=False,
                 max_len=None, reserved_words=None, embed_type='dense_auto', adapt_cutoff=None, adapt_factor=4,
                 embeddings_initializer='uniform', **kwargs):
        super().__init__(**kwargs)
        self.input_spec = layers.InputSpec(min_ndim=1, max_ndim=2, dtype='string')

        if not isinstance(vocabulary, list) or not all(map(lambda x: isinstance(x, str), vocabulary)):
            raise ValueError('Expected "vocabulary" to be a list of strings')
        if len(vocabulary) != len(set(vocabulary)):
            raise ValueError('Expected "vocabulary" to contain unique values')
        self.vocabulary = vocabulary

        self.output_dim = output_dim
        self.normalize_unicode = normalize_unicode
        self.lower_case = lower_case
        self.zero_digits = zero_digits

        if max_len is not None and max_len < 3:
            raise ValueError('Expected "max_len" to be None or greater then 2')
        self.max_len = max_len

        if reserved_words and len(reserved_words) != len(set(reserved_words)):
            raise ValueError('Expected "reserved_words" to contain unique values')
        self.reserved_words = reserved_words

        if embed_type not in {'dense_auto', 'dense_cpu', 'adapt'}:
            raise ValueError('Expected "embed_type" to be one of "dense_auto", "dense_cpu" or "adapt"')
        self.embed_type = embed_type

        self.adapt_cutoff = adapt_cutoff
        self.adapt_factor = adapt_factor
        self.embeddings_initializer = initializers.get(embeddings_initializer)

        all_reserved_words = [] if reserved_words is None else [r for r in reserved_words if self.UNK_MARK != r]
        self._reserved_words = [self.UNK_MARK] + all_reserved_words

        miss_reserved_words = [m for m in self._reserved_words if m not in vocabulary]
        if miss_reserved_words:
            tf.get_logger().warning('Vocabulary is missing some reserved_words values: {}. '
                                    'This may indicate an error in vocabulary estimation'.format(miss_reserved_words))

        clean_vocab = [w for w in vocabulary if w not in self._reserved_words]
        self._vocabulary = self._reserved_words + clean_vocab

    def vocab(self, word_counts, **kwargs):
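        # Re-count word frequencies after applying the adapt() normalization, so the
        # resulting vocabulary matches the token forms the layer will actually look up.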
        if not word_counts:
            raise ValueError('Can\'t estimate vocabulary with empty word counter')
        if not all(map(lambda k: isinstance(k, str), word_counts.keys())):
            raise ValueError('Expected all words to be strings')

        word_counts = Vocabulary(word_counts)
        word_tokens = word_counts.tokens()
        adapt_words = self.adapt(word_tokens)
        if 1 == adapt_words.shape.rank:
            adapt_words = adapt_words[..., None]

        adapt_counts = Vocabulary()
        for adapts, word in zip(adapt_words, word_tokens):
            adapts = np.char.decode(adapts.numpy().reshape([-1]).astype('S'), 'utf-8')
            for adapt in adapts:
                adapt_counts[adapt] += word_counts[word]

        return adapt_counts

    @tf_utils.shape_type_conversion
    def build(self, input_shape=None):
        self.squeeze = False
        if 2 == len(input_shape):
            if 1 != input_shape[-1]:
                raise ValueError(
                    'Input 0 of layer {} is incompatible with the layer: if ndim=2 expected axis[-1]=1, found '
                    'axis[-1]={}. Full shape received: {}'.format(self.name, input_shape[-1], input_shape))

            self.squeeze = True
            input_shape = input_shape[:1]

        self.lookup = StringLookup(vocabulary=self._vocabulary, mask_token=None, oov_token=self.UNK_MARK)
        self.lookup.build(input_shape)

        if 'adapt' == self.embed_type:
            self.embed = AdaptiveEmbedding(
                self.adapt_cutoff, self.lookup.vocabulary_size(), self.output_dim, factor=self.adapt_factor,
                embeddings_initializer=self.embeddings_initializer)
        else:
            self.embed = layers.Embedding(
                self.lookup.vocabulary_size(), self.output_dim, embeddings_initializer=self.embeddings_initializer)
            if 'dense_auto' == self.embed_type:
                self.embed.build(input_shape)
            else:  # 'dense_cpu' == self.embed_type
                with tf.device('cpu:0'):
                    self.embed.build(input_shape)

        super().build(input_shape)

    def adapt(self, inputs):
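        # Normalize raw string tokens (unicode form, casing, digits) and shorten
        # over-long words before they are looked up in the vocabulary.
        # Reserved words are skipped by every transformation.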
        inputs = tf.convert_to_tensor(inputs, dtype='string')

        if self.normalize_unicode:
            inputs = miss_text.normalize_unicode(inputs, form=self.normalize_unicode, skip=self._reserved_words)
        if self.lower_case:
            inputs = miss_text.lower_case(inputs, skip=self._reserved_words)
        if self.zero_digits:
            inputs = miss_text.zero_digits(inputs, skip=self._reserved_words)

        if self.max_len is not None:
            inputs_ = tf.stack([
                miss_text.sub_string(inputs, 0, self.max_len // 2, skip=self._reserved_words),
                tf.fill(tf.shape(inputs), self.REP_CHAR),
                miss_text.sub_string(inputs, -self.max_len // 2 + 1, -1, skip=self._reserved_words)],
                axis=-1)
            inputs_ = tf.strings.reduce_join(inputs_, axis=-1)
            sizes = tf.strings.length(inputs, unit='UTF8_CHAR')
            inputs = tf.where(sizes > self.max_len, inputs_, inputs)

        return inputs

    def call(self, inputs, **kwargs):
        if self.squeeze:
            # Workaround for Sequential model test
            inputs = tf.squeeze(inputs, axis=-1)

        adapts = self.adapt(inputs)
        indices = self.lookup(adapts)
        outputs = self.embed(indices)

        return outputs

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        return input_shape + (self.output_dim,)

    def get_config(self):
        config = super().get_config()
        config.update({
            'vocabulary': self.vocabulary,
            'output_dim': self.output_dim,
            'normalize_unicode': self.normalize_unicode,
            'lower_case': self.lower_case,
            'zero_digits': self.zero_digits,
            'max_len': self.max_len,
            'reserved_words': self.reserved_words,
            'embed_type': self.embed_type,
            'adapt_cutoff': self.adapt_cutoff,
            'adapt_factor': self.adapt_factor,
            'embeddings_initializer': initializers.serialize(self.embeddings_initializer)
        })

        return config
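
# Hypothetical usage sketch (assumption: the helper packages providing Vocabulary,
# AdaptiveEmbedding and miss_text are installed); illustrative only.
#
# layer = WordEmbedding(vocabulary=['[UNK]', 'the', 'cat'], output_dim=8,
#                       lower_case=True, zero_digits=True)
# vectors = layer(tf.constant(['The', 'cat', '12345']))  # -> shape (3, 8)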