Example #1
    def get_initial_loop_state(self) -> LoopState:
        # Start from the generic autoregressive loop state and extend its
        # feedables and histories with the RNN-specific entries below.
        default_ls = AutoregressiveDecoder.get_initial_loop_state(self)
        feedables = default_ls.feedables._asdict()
        histories = default_ls.histories._asdict()

        feedables["prev_contexts"] = [
            tf.zeros([self.batch_size, a.context_vector_size])
            for a in self.attentions
        ]

        feedables["prev_rnn_state"] = self.initial_state
        feedables["prev_rnn_output"] = self.initial_state

        histories["attention_histories"] = [
            a.initial_loop_state() for a in self.attentions if a is not None
        ]

        histories["decoder_outputs"] = tf.zeros(
            shape=[0, self.batch_size, self.rnn_size],
            dtype=tf.float32,
            name="hist_decoder_outputs")

        # pylint: disable=not-callable
        rnn_feedables = RNNFeedables(**feedables)
        rnn_histories = RNNHistories(**histories)
        # pylint: enable=not-callable

        return LoopState(histories=rnn_histories,
                         constants=default_ls.constants,
                         feedables=rnn_feedables)
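The `_asdict()` / rebuild idiom above is how the decoder widens its parent's
immutable loop state with extra fields. Below is a minimal, self-contained
sketch of the same pattern; the `BaseState` and `ExtendedState` namedtuples
are hypothetical stand-ins for the framework's state classes, not part of
Neural Monkey.

    from collections import namedtuple

    # Hypothetical stand-ins for the framework's namedtuple state classes.
    BaseState = namedtuple("BaseState", ["step", "finished"])
    ExtendedState = namedtuple("ExtendedState",
                               ["step", "finished", "prev_output"])

    def extend_state(base: BaseState, prev_output) -> ExtendedState:
        # Dump the immutable namedtuple to a dict, add the new field, then
        # rebuild a wider namedtuple -- mirroring how the decoder above adds
        # "prev_contexts" or "attention_histories" to the default state.
        fields = base._asdict()
        fields["prev_output"] = prev_output
        return ExtendedState(**fields)

    base = BaseState(step=0, finished=False)
    print(extend_state(base, prev_output=None))
    # ExtendedState(step=0, finished=False, prev_output=None)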
Example #2

    def get_initial_loop_state(self) -> LoopState:

        default_ls = AutoregressiveDecoder.get_initial_loop_state(self)
        histories = default_ls.histories._asdict()

        histories["self_attention_histories"] = [
            empty_multi_head_loop_state(self.n_heads_self)
            for a in range(self.depth)]

        histories["inter_attention_histories"] = [
            empty_multi_head_loop_state(self.n_heads_enc)
            for a in range(self.depth)]

        histories["decoded_symbols"] = tf.TensorArray(
            dtype=tf.int32, dynamic_size=True, size=0,
            clear_after_read=False, name="decoded_symbols")

        input_mask = tf.TensorArray(
            dtype=tf.float32, dynamic_size=True, size=0,
            clear_after_read=False, name="input_mask")

        histories["input_mask"] = input_mask.write(
            0, tf.ones_like(self.go_symbols, dtype=tf.float32))

        # TransformerHistories is a type and should be callable
        # pylint: disable=not-callable
        tr_histories = TransformerHistories(**histories)
        # pylint: enable=not-callable

        return LoopState(
            histories=tr_histories,
            constants=[],
            feedables=default_ls.feedables)
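Example #2 accumulates its histories in `tf.TensorArray`s, so entries written
at earlier steps stay readable (`clear_after_read=False`) and the array grows
as decoding proceeds (`dynamic_size=True`). A minimal sketch of that pattern,
run eagerly under TensorFlow 2 for illustration; Neural Monkey builds the
equivalent ops in graph mode inside a `tf.while_loop`.

    import tensorflow as tf

    batch_size = 3

    # Dynamically sized array whose entries remain readable after writes,
    # so earlier decoded symbols can be re-read at every step.
    symbols = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True,
                             clear_after_read=False, name="decoded_symbols")

    for step in range(4):
        # One (dummy) decoded symbol per sentence in the batch.
        symbols = symbols.write(step, tf.fill([batch_size], step))

    # stack() yields shape (time, batch), the layout the histories use.
    print(symbols.stack().shape)  # (4, 3)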
Example #3
    def get_initial_loop_state(self) -> LoopState:

        default_ls = AutoregressiveDecoder.get_initial_loop_state(self)
        histories = default_ls.histories._asdict()

        # histories["self_attention_histories"] = [
        #     empty_multi_head_loop_state(self.batch_size, self.n_heads_self)
        #     for a in range(self.depth)]

        # histories["inter_attention_histories"] = [
        #     empty_multi_head_loop_state(self.batch_size, self.n_heads_enc)
        #     for a in range(self.depth)]

        histories["decoded_symbols"] = tf.zeros(shape=[0, self.batch_size],
                                                dtype=tf.int32,
                                                name="decoded_symbols")

        histories["input_mask"] = tf.zeros(shape=[0, self.batch_size],
                                           dtype=tf.float32,
                                           name="input_mask")

        # TransformerHistories is a type and should be callable
        # pylint: disable=not-callable
        tr_histories = TransformerHistories(**histories)
        # pylint: enable=not-callable

        return LoopState(histories=tr_histories,
                         constants=[],
                         feedables=default_ls.feedables)
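Example #3 takes the alternative route: each history starts as a zero-length,
time-major tensor of shape `(0, batch)` and grows by concatenation, one row
per decoding step. A sketch of that growth pattern with dummy data, again in
eager TensorFlow 2; inside a real `tf.while_loop`, the growing time dimension
would additionally need a `None` entry in the shape invariants.

    import tensorflow as tf

    batch_size = 3

    # Zero-length seed of shape (0, batch), as in Example #3 above.
    decoded_symbols = tf.zeros(shape=[0, batch_size], dtype=tf.int32)

    for step in range(4):
        new_row = tf.fill([1, batch_size], step)  # dummy symbols, one step
        decoded_symbols = tf.concat([decoded_symbols, new_row], axis=0)

    print(decoded_symbols.shape)  # (4, 3): time-major, like the TensorArray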
    def train_loop_result(self) -> LoopState:
        # We process all the decoding steps together during training.
        # However, we still want to pretend that a proper decoding_loop
        # was called.
        decoder_ls = AutoregressiveDecoder.get_initial_loop_state(self)

        input_sequence = self.embed_input_symbols(self.train_input_symbols)
        input_mask = tf.transpose(self.train_mask)

        last_layer = self.layer(
            self.depth, input_sequence, input_mask)

        tr_feedables = TransformerFeedables(
            input_sequence=input_sequence,
            input_mask=tf.expand_dims(input_mask, -1))

        # t_states shape: (batch, time, channels)
        # dec_w shape: (channels, vocab)
        last_layer_shape = tf.shape(last_layer.temporal_states)
        last_layer_states = tf.reshape(
            last_layer.temporal_states,
            [-1, last_layer_shape[-1]])

        # shape (batch, time, vocab)
        logits = tf.reshape(
            tf.matmul(last_layer_states, self.decoding_w),
            [last_layer_shape[0], last_layer_shape[1], len(self.vocabulary)])
        logits += tf.reshape(self.decoding_b, [1, 1, -1])

        # TODO: record histories properly
        tr_histories = tf.zeros([])
        # tr_histories = TransformerHistories(
        #    self_attention_histories=[
        #        empty_multi_head_loop_state(self.batch_size,
        #                                    self.n_heads_self)
        #        for a in range(self.depth)],
        #    encoder_attention_histories=[
        #        empty_multi_head_loop_state(self.batch_size,
        #                                    self.n_heads_enc)
        #        for a in range(self.depth)])

        feedables = DecoderFeedables(
            step=last_layer_shape[1],
            finished=tf.ones([self.batch_size], dtype=tf.bool),
            embedded_input=self.embed_input_symbols(tf.tile(
                [END_TOKEN_INDEX], [self.batch_size])),
            other=tr_feedables)

        histories = DecoderHistories(
            logits=tf.transpose(logits, perm=[1, 0, 2]),
            output_states=tf.transpose(
                last_layer.temporal_states, [1, 0, 2]),
            output_mask=self.train_mask,
            output_symbols=self.train_inputs,
            other=tr_histories)

        return LoopState(
            feedables=feedables,
            histories=histories,
            constants=decoder_ls.constants)
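The logits computation in `train_loop_result` uses a common flatten-matmul-
restore trick: collapse `(batch, time, channels)` into a 2-D matrix, project
every position onto the vocabulary with a single matmul, and reshape back. A
self-contained sketch with dummy shapes standing in for
`last_layer.temporal_states`, `self.decoding_w` and `self.decoding_b`:

    import tensorflow as tf

    batch, time, channels, vocab = 2, 5, 8, 11

    states = tf.random.normal([batch, time, channels])  # temporal states
    dec_w = tf.random.normal([channels, vocab])         # decoding weights
    dec_b = tf.zeros([vocab])                           # decoding bias

    # Flatten (batch, time, channels) -> (batch*time, channels) so one
    # 2-D matmul projects all positions onto the vocabulary at once.
    flat = tf.reshape(states, [-1, channels])
    logits = tf.reshape(tf.matmul(flat, dec_w), [batch, time, vocab])
    logits += tf.reshape(dec_b, [1, 1, -1])  # broadcast bias over positions

    print(logits.shape)  # (2, 5, 11)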