def test_step(self, inputs):
    """Run one evaluation step, dispatching on the configured execution mode.

    Args:
        inputs: A batch of test inputs, forwarded unchanged to the
            underlying step implementation.

    Returns:
        Whatever ``_test_step`` / ``_tf_test_step`` returns for this batch.
    """
    # Fresh checker per step so shape assertions don't carry state
    # across batches.
    self.shape_checker = ShapeChecker()
    if self.enable_eager_execution:
        # Eager path: easier to debug, no tf.function tracing.
        return self._test_step(inputs)
    else:
        return self._tf_test_step(inputs)
Example #2
0
    def __call__(self, y_true, y_pred):
        """Return the total loss over a batch, ignoring padded positions.

        Positions where ``y_true == 0`` are treated as padding and
        contribute nothing to the sum.
        """
        check = ShapeChecker()
        # The expected trailing time axis depends on whether we score
        # whole sequences or single steps.
        steps = ('t',) if self.sequence else ()
        check(y_true, ('batch',) + steps)
        check(y_pred, ('batch',) + steps + ('logits',))

        # Per-element loss from the wrapped loss object.
        elementwise = self.loss(y_true, y_pred)
        check(elementwise, ('batch',) + steps)

        # Zero out contributions from padding tokens (token id 0).
        pad_mask = tf.cast(y_true != 0, tf.float32)
        check(pad_mask, ('batch',) + steps)

        # Total (not mean) over the batch.
        return tf.reduce_sum(elementwise * pad_mask)
Example #3
0
    def call(self, query, value, mask):
        """Attend over `value` for each `query` time step (Bahdanau-style)."""
        check = ShapeChecker()
        check(query, ('batch', 't', 'query_units'))
        check(value, ('batch', 's', 'value_units'))
        check(mask, ('batch', 's'))

        # Eqn. (4): project the decoder queries, `W1@ht`.
        projected_query = self.W1(query)
        check(projected_query, ('batch', 't', 'attn_units'))

        # Eqn. (4): project the encoder keys, `W2@hs`.
        projected_key = self.W2(value)
        check(projected_key, ('batch', 's', 'attn_units'))

        # Every query position is valid; padding only exists on the
        # value side, described by `mask`.
        all_valid = tf.ones(tf.shape(query)[:-1], dtype=bool)

        context_vector, attention_weights = self.attention(
            inputs=[projected_query, value, projected_key],
            mask=[all_valid, mask],
            return_attention_scores=True,
        )
        check(context_vector, ('batch', 't', 'value_units'))
        check(attention_weights, ('batch', 't', 's'))

        return context_vector, attention_weights
Example #4
0
    def call(self, tokens, state=None):
        """Encode a batch of token ids into a sequence of hidden states."""
        check = ShapeChecker()
        check(tokens, ('batch', 's'))

        # Token ids -> dense embedding vectors.
        embedded = self.embedding(tokens)
        check(embedded, ('batch', 's', 'embed_dim'))

        # Run the GRU over the embedded sequence; a provided `state`
        # seeds the recurrent state.
        output, state = self.gru(embedded, initial_state=state)
        check(output, ('batch', 's', 'enc_units'))
        check(state, ('batch', 'enc_units'))

        # New sequence representation plus the final recurrent state.
        return output, state
Example #5
0
    def __init__(self, embedding_dim, units,
                 input_text_processor,
                 output_text_processor,
                 use_tf_function=True):
        """Wire up the encoder/decoder pair and keep the text processors."""
        super().__init__()
        # Vocabulary sizes come from the fitted text processors.
        self.encoder = Encoder(input_text_processor.vocabulary_size(),
                               embedding_dim, units)
        self.decoder = Decoder(output_text_processor.vocabulary_size(),
                               embedding_dim, units)
        self.input_text_processor = input_text_processor
        self.output_text_processor = output_text_processor
        # When True, train/test steps go through tf.function tracing.
        self.use_tf_function = use_tf_function
        self.shape_checker = ShapeChecker()
    def __init__(
            self,
            embedding_matrix,
            enc_units,
            dec_units,
            tokenizer,
            rnn_type='GRU',
            enable_eager_execution=False,
            use_nearest_token_embedding=Config['use_nearest_token_embedding'],
            train_embeddings=False,
            enc_return_seq=False,
            dec_return_seq=False):
        """Build an autoencoder around a pretrained embedding matrix.

        Args:
            embedding_matrix: 2-D array of pretrained embeddings,
                presumably shaped (vocab_size, embedding_dim) — rows are
                looked up by token id.
            enc_units: Hidden size of the encoder RNN.
            dec_units: Hidden size of the decoder RNN.
            tokenizer: Fitted tokenizer exposing ``vocabulary_size()``.
            rnn_type: Recurrent cell type for both encoder and decoder
                ('GRU' or, per the Encoder/Decoder, an LSTM variant).
            enable_eager_execution: If True, steps run eagerly instead of
                through tf.function.
            use_nearest_token_embedding: If True, precompute the transposed
                L2-normalized embedding matrix for nearest-token lookup.
                NOTE: the default is read from ``Config`` once, at
                class-definition time.
            train_embeddings: Whether the embedding layer is trainable.
            enc_return_seq: Encoder returns full sequences if True.
            dec_return_seq: Decoder returns full sequences if True.
        """
        super().__init__()
        # Pre-defined tokenizer drives the output vocabulary size.
        self.tokenizer = tokenizer
        self.vocab_size = tokenizer.vocabulary_size()
        self.enable_eager_execution = enable_eager_execution
        self.use_nearest_token_embedding = use_nearest_token_embedding
        if self.use_nearest_token_embedding:
            # Normalize + transpose once up front so later nearest-token
            # lookups reduce to a single matmul.
            self.norm_embedding_matrix_transposed = tf.transpose(
                tf.math.l2_normalize(embedding_matrix, axis=-1), perm=[1, 0])
        else:
            self.norm_embedding_matrix_transposed = None

        # Shared embedding layer initialized from the pretrained matrix.
        self.embedding_layer = tf.keras.layers.Embedding(
            embedding_matrix.shape[0],
            embedding_matrix.shape[1],
            embeddings_initializer=tf.keras.initializers.Constant(
                embedding_matrix),
            trainable=train_embeddings,
        )

        # Encoder & decoder share the embedding layer above.
        self.encoder = Encoder(self.embedding_layer,
                               enc_units,
                               recurrent_layer_type=rnn_type,
                               return_seq=enc_return_seq)
        self.decoder = Decoder(self.embedding_layer,
                               dec_units,
                               self.vocab_size,
                               recurrent_layer_type=rnn_type,
                               return_seq=dec_return_seq)

        self.shape_checker = ShapeChecker()
Example #7
0
    def call(self,
             inputs: DecoderInput,
             state=None) -> Tuple[DecoderOutput, tf.Tensor]:
        """Decode a chunk of target tokens with attention over the encoder.

        Returns a ``DecoderOutput`` (logits + attention weights) and the
        updated recurrent state.
        """
        check = ShapeChecker()
        check(inputs.new_tokens, ('batch', 't'))
        check(inputs.enc_output, ('batch', 's', 'enc_units'))
        check(inputs.mask, ('batch', 's'))
        if state is not None:
            check(state, ('batch', 'dec_units'))

        # Step 1: embed the incoming target tokens.
        embedded = self.embedding(inputs.new_tokens)
        check(embedded, ('batch', 't', 'embedding_dim'))

        # Step 2: advance the recurrent state over the embedded tokens.
        rnn_output, state = self.gru(embedded, initial_state=state)
        check(rnn_output, ('batch', 't', 'dec_units'))
        check(state, ('batch', 'dec_units'))

        # Step 3: attend over the encoder output, querying with the RNN
        # output; padding positions are masked out.
        context_vector, attention_weights = self.attention(
            query=rnn_output, value=inputs.enc_output, mask=inputs.mask)
        check(context_vector, ('batch', 't', 'dec_units'))
        check(attention_weights, ('batch', 't', 's'))

        # Step 4, Eqn. (3): at = tanh(Wc@[ct; ht]).
        attention_vector = self.Wc(
            tf.concat([context_vector, rnn_output], axis=-1))
        check(attention_vector, ('batch', 't', 'dec_units'))

        # Step 5: project to vocabulary logits.
        logits = self.fc(attention_vector)
        check(logits, ('batch', 't', 'output_vocab_size'))

        return DecoderOutput(logits, attention_weights), state
Example #8
0
    def call(self, new_tokens, enc_output, state=None, mask=None):
        """Run one decoder step conditioned on the encoder summary vector.

        `mask` is accepted for interface compatibility but unused here.
        """
        check = ShapeChecker()
        check(new_tokens, ('batch', 't'))
        check(enc_output, ('batch', 'enc_units'))

        # Step 1: embed the new target tokens.
        embedded = self.embedding(new_tokens)
        check(embedded, ('batch', 't', 'embedding_dim'))

        # Step 2: advance the recurrent layer. A GRU returns a single
        # state tensor; the LSTM path returns (h, c) separately, which
        # we repack as a list.
        if self.recurrent_layer_type == 'GRU':
            rnn_output, state = self.recurrent_layer(embedded,
                                                     initial_state=state)
        else:
            rnn_output, state_h, state_c = self.recurrent_layer(
                embedded, initial_state=state)
            state = [state_h, state_c]

        expected_shape = (('batch', 't', 'dec_units') if self.return_seq
                          else ('batch', 'dec_units'))
        check(rnn_output, expected_shape)

        # Steps 3-4: concatenate the encoder summary onto the RNN output
        # and project through the dense layer.
        dec_output = self.fc(tf.concat([rnn_output, enc_output], axis=-1))
        # NOTE(review): the label 'embedding_dim' looks odd for a logits
        # tensor — presumably `fc` outputs vocab-sized logits; confirm
        # against the layer definition before relying on this check.
        check(dec_output, ('batch', 'embedding_dim'))

        return dec_output, state
Example #9
0
    def call(self, tokens, state=None):
        """Encode a token batch; handles GRU/LSTM, uni- or bidirectional."""
        check = ShapeChecker()
        check(tokens, ('batch', 's'))

        # Token ids -> embedding vectors.
        embedded = self.embedding(tokens)
        check(embedded, ('batch', 's', 'embed_dim'))

        if self.bidirectional:
            outputs = self.recurrent_layer(embedded, initial_state=state)
            if self.recurrent_layer_type == 'GRU':
                # outputs = (sequence, fwd_state, bwd_state):
                # keep only the forward state.
                state = outputs[1]
            else:
                # outputs = (sequence, fwd_h, fwd_c, bwd_h, bwd_c):
                # keep the forward (h, c) pair.
                state = outputs[1:3]
            # Dense layer maps the concatenated bidirectional features
            # back down (presumably to enc_units — confirm).
            output = self.dense_layer(outputs[0])
        elif self.recurrent_layer_type == 'GRU':
            output, state = self.recurrent_layer(embedded,
                                                 initial_state=state)
        else:
            # LSTM state is a list: [state_h, state_c].
            output, state_h, state_c = self.recurrent_layer(
                embedded, initial_state=state)
            state = [state_h, state_c]

        # Sequence (or summary, per return_seq) plus recurrent state.
        return output, state
Example #10
0
 def train_step(self, inputs):
     """Run one training step, compiled or eager per `use_tf_function`."""
     # Fresh checker each step so shape assertions don't accumulate.
     self.shape_checker = ShapeChecker()
     # tf.function path traces/compiles the step; plain path runs eagerly.
     step = self._tf_train_step if self.use_tf_function else self._train_step
     return step(inputs)