def build_graph(self):
        """Build the Transformer graph.

        TRAIN mode: one teacher-forced forward pass, then attach loss and
        optimizer ops.  Any other mode: decode autoregressively by
        re-running the decoder once per output position and feeding each
        predicted token back into the decoder inputs.
        """
        graph = transformer.Graph(self.mode)
        # First forward pass; outside TRAIN mode this supplies the logits
        # used to fill decoder position 1 below.
        output = graph.build(encoder_inputs=self.encoder_inputs,
                             decoder_inputs=self.decoder_inputs)

        if self.mode == tf.estimator.ModeKeys.TRAIN:
            self._build_loss(output)
            self._build_optimizer()
        else:

            def _filled_next_token(inputs, logits, decoder_index):
                # Name the per-step argmax of the first batch element so it
                # can be fetched/logged as 'test/pred_<i>'; the identity op
                # itself is otherwise unused.
                tf.identity(tf.argmax(logits[0], axis=1, output_type=tf.int32),
                            f'test/pred_{decoder_index}')

                # Greedy decode: argmax over the vocab axis, then take the
                # column at decoder_index - 1, i.e. the token predicted FOR
                # position decoder_index.
                next_token = tf.slice(
                    tf.argmax(logits, axis=2, output_type=tf.int32),
                    [0, decoder_index - 1], [self.batch_size, 1])
                # Zero-pad left and right so the single token lands at
                # column decoder_index of a full-length row.
                left_zero_pads = tf.zeros([self.batch_size, decoder_index],
                                          dtype=tf.int32)
                right_zero_pads = tf.zeros([
                    self.batch_size,
                    (Config.data.max_seq_length - decoder_index - 1)
                ],
                                           dtype=tf.int32)
                next_token = tf.concat(
                    (left_zero_pads, next_token, right_zero_pads), axis=1)

                # NOTE(review): addition only works if column decoder_index
                # of `inputs` is currently zero — confirm decoder inputs are
                # zero-initialized past the start token.
                return inputs + next_token

            encoder_outputs = graph.encoder_outputs
            # Fill position 1 from the initial pass, then loop over the
            # remaining positions, rebuilding embed/decoder/output with
            # reuse=True so variables are shared with the first pass.
            decoder_inputs = _filled_next_token(self.decoder_inputs, output, 1)

            # predict output with loop. [encoder_outputs, decoder_inputs (filled next token)]
            for i in range(2, Config.data.max_seq_length):
                decoder_emb_inp = graph.build_embed(decoder_inputs,
                                                    encoder=False,
                                                    reuse=True)
                decoder_outputs = graph.build_decoder(decoder_emb_inp,
                                                      encoder_outputs,
                                                      reuse=True)
                next_output = graph.build_output(decoder_outputs, reuse=True)

                decoder_inputs = _filled_next_token(decoder_inputs,
                                                    next_output, i)

            # NOTE(review): if Config.data.max_seq_length <= 2 the loop body
            # never runs and `next_output` is undefined here (NameError).
            self._build_loss(next_output)

            # slice start_token
            # Drop the start token at position 0 and append a zero column so
            # predictions keep shape [batch_size, max_seq_length].
            decoder_input_start_1 = tf.slice(
                decoder_inputs, [0, 1],
                [self.batch_size, Config.data.max_seq_length - 1])
            self.predictions = tf.concat([
                decoder_input_start_1,
                tf.zeros([self.batch_size, 1], dtype=tf.int32)
            ],
                                         axis=1)
# ---- Beispiel #2 (Example #2) ----
    def build_graph(self):
        """Assemble the Transformer graph and store its predictions;
        outside PREDICT mode also attach loss, optimizer and metric ops."""
        net = transformer.Graph(self.mode)
        logits, preds = net.build(encoder_inputs=self.encoder_inputs,
                                  decoder_inputs=self.decoder_inputs)
        self.predictions = preds

        is_predict_mode = self.mode == tf.estimator.ModeKeys.PREDICT
        if not is_predict_mode:
            self._build_loss(logits)
            self._build_optimizer()
            self._build_metric()
# ---- Beispiel #3 (Example #3) ----
 def create_network_transformer(self):
     """Build a Transformer-based regressor: run encoder and decoder over
     the same input and take an L2 loss against the target."""
     import transformer
     from transformer.hbconfig import Config

     net = transformer.Graph(tf.estimator.ModeKeys.TRAIN)
     # Config fields are mutated after Graph construction; presumably the
     # Graph reads them lazily at build() time — TODO confirm.
     Config.data.max_seq_length = self.rnn_pp_hist
     Config.data.n_classes = self.conf.n_classes
     Config.data.target_vocab_size = 2 * self.conf.n_classes

     # Encoder and decoder both consume inputs[0]; keep only the final
     # time step of the first build output.
     build_result = net.build(self.inputs[0], self.inputs[0])
     last_step = build_result[0][:, -1, :]
     l2_cost = tf.nn.l2_loss(last_step - self.inputs[1])
     self.pred = last_step
     self.cost = l2_cost
# ---- Beispiel #4 (Example #4) ----
    def build_graph(self):
        """Build the Transformer graph.

        TRAIN mode: one teacher-forced forward pass, then attach loss and
        optimizer ops.  Any other mode: decode autoregressively one
        position at a time, accumulating per-step logits into
        `sequence_logits` for the loss.
        """
        graph = transformer.Graph(self.mode)
        # First forward pass; outside TRAIN mode this supplies the logits
        # used to fill decoder position 1 below.
        output = graph.build(encoder_inputs=self.encoder_inputs,
                             decoder_inputs=self.decoder_inputs)

        if self.mode == tf.estimator.ModeKeys.TRAIN:
            self._build_loss(output)
            self._build_optimizer()
        else:

            def _filled_next_token(inputs, logits, decoder_index):
                # Greedy decode: argmax over the vocab axis, one token per
                # batch row.  NOTE(review): axis=1 + the reshape implies the
                # logits here are 2-D [batch, vocab] (a single step), unlike
                # the 3-D logits of other variants — confirm.
                next_token = tf.reshape(
                    tf.argmax(logits, axis=1, output_type=tf.int32),
                    [Config.model.batch_size, 1])
                # Zero-pad left and right so the token lands at column
                # decoder_index of a full-length row.
                left_zero_pads = tf.zeros(
                    [Config.model.batch_size, decoder_index], dtype=tf.int32)
                right_zero_pads = tf.zeros([
                    Config.model.batch_size,
                    (Config.data.max_seq_length - decoder_index - 1)
                ],
                                           dtype=tf.int32)
                next_token = tf.concat(
                    (left_zero_pads, next_token, right_zero_pads), axis=1)

                # NOTE(review): addition assumes column decoder_index of
                # `inputs` is currently zero — confirm.
                return inputs + next_token

            encoder_outputs = graph.encoder_outputs
            # Fill position 1 from the initial pass and seed the running
            # per-step logits tensor with shape [batch, 1, vocab].
            decoder_inputs = _filled_next_token(self.decoder_inputs, output, 1)
            sequence_logits = tf.reshape(
                output, [-1, 1, Config.data.target_vocab_size])

            # predict output with loop. [encoder_outputs, decoder_inputs (filled next token)]
            for i in range(2, Config.data.max_seq_length):
                # Rebuild embed/decoder/output with reuse=True so variables
                # are shared with the first pass.
                decoder_emb_inp = graph.build_embed(decoder_inputs,
                                                    encoder=False,
                                                    reuse=True)
                decoder_outputs = graph.build_decoder(decoder_emb_inp,
                                                      encoder_outputs,
                                                      reuse=True)
                next_output = graph.build_output(decoder_outputs, reuse=True)

                decoder_inputs = _filled_next_token(decoder_inputs,
                                                    next_output, i)
                # Append this step's logits along the time axis;
                # `sequence_logits` ends up with max_seq_length - 1 steps.
                sequence_logits = tf.concat(
                    (sequence_logits,
                     tf.reshape(next_output,
                                [-1, 1, Config.data.target_vocab_size])),
                    axis=1)

            self._build_loss(sequence_logits)
            # Predictions are the fully filled decoder inputs (start token
            # still at position 0, unlike the variant that slices it off).
            self.predictions = decoder_inputs