Example #1
        def inner_loop(i, hit_eos, decoded_ids):
            # Re-embed the full prefix decoded so far and run it through
            # the decoder stack against the encoder output.
            tgt_embed = self.tgt_embedding.encode(decoded_ids)
            T = get_shape_as_list(tgt_embed)[1]
            tgt_mask = subsequent_mask(T)
            scope = 'TransformerDecoder'
            h = transformer_decoder_stack(tgt_embed, src_enc, src_mask,
                                          tgt_mask, num_heads, pdrop, scale,
                                          layers, activation_type, scope, d_ff)

            vsz = self.tgt_embedding.vsz
            # Weight tying (on by default) reuses the embedding matrix as
            # the output projection.
            do_weight_tying = bool(kwargs.get('tie_weights', True))
            hsz = get_shape_as_list(h)[-1]
            h = tf.reshape(h, [-1, hsz])
            if do_weight_tying and hsz == self.tgt_embedding.get_dsz():
                with tf.variable_scope(self.tgt_embedding.scope, reuse=True):
                    W = tf.get_variable("W")
                    outputs = tf.matmul(h, W, transpose_b=True, name="logits")
            else:
                vocab_w = tf.get_variable("vocab_w", [hsz, vsz],
                                          dtype=tf.float32)
                vocab_b = tf.get_variable("vocab_b", [vsz], dtype=tf.float32)
                outputs = tf.nn.xw_plus_b(h, vocab_w, vocab_b, name="logits")

            # Greedy step: take the most likely token at the last time
            # step and mark rows that just produced EOS.
            preds = tf.reshape(outputs, [B, T, vsz])
            next_id = tf.argmax(preds, axis=-1)[:, -1]
            hit_eos |= tf.equal(next_id, Offsets.EOS)
            next_id = tf.reshape(next_id, [B, 1])

            decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
            return i + 1, hit_eos, decoded_ids
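A note on usage: this inner_loop is shaped as a tf.while_loop body for greedy decoding, re-running the decoder over the whole prefix and appending one token per iteration. A minimal TF 1.x sketch of the surrounding driver, assuming a batch size B, a decode budget MAX_T, and seeding with Offsets.GO (none of which are shown in the snippet):

        def not_done(i, hit_eos, decoded_ids):
            # Loop until every row has emitted EOS or the budget is spent.
            return tf.logical_and(i < MAX_T,
                                  tf.logical_not(tf.reduce_all(hit_eos)))

        i0 = tf.constant(0)
        hit_eos0 = tf.zeros([B], dtype=tf.bool)
        # tf.argmax yields int64, so the seed ids must match that dtype.
        decoded0 = tf.fill([B, 1], tf.constant(Offsets.GO, dtype=tf.int64))

        _, _, decoded_ids = tf.while_loop(
            not_done, inner_loop,
            loop_vars=[i0, hit_eos0, decoded0],
            # decoded_ids grows along the time axis, so its shape
            # invariant must leave that dimension unconstrained.
            shape_invariants=[i0.get_shape(), hit_eos0.get_shape(),
                              tf.TensorShape([None, None])])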
Example #2
        def inner_loop(i, hit_eos, decoded_ids):

            tgt_embed = self.tgt_embedding.encode(decoded_ids)
            T = get_shape_as_list(tgt_embed)[1]
            tgt_mask = subsequent_mask(T)
            scope = 'TransformerDecoder'
            h = transformer_decoder_stack(src_enc, tgt_embed, src_mask,
                                          tgt_mask, num_heads, pdrop, scale,
                                          layers, activation_type, scope, d_ff)

            vsz = self.tgt_embedding.vsz
            do_weight_tying = bool(kwargs.get('tie_weights', True))
            hsz = get_shape_as_list(h)[-1]
            h = tf.reshape(h, [-1, hsz])
            if do_weight_tying and hsz == self.tgt_embedding.get_dsz():
                with tf.variable_scope(self.tgt_embedding.scope, reuse=True):
                    W = tf.get_variable("W")
                    outputs = tf.matmul(h, W, transpose_b=True, name="logits")
            else:
                vocab_w = tf.get_variable("vocab_w", [hsz, vsz], dtype=tf.float32)
                vocab_b = tf.get_variable("vocab_b", [vsz], dtype=tf.float32)
                outputs = tf.nn.xw_plus_b(h, vocab_w, vocab_b, name="logits")

            preds = tf.reshape(outputs, [B, T, vsz])
            next_id = tf.argmax(preds, axis=-1)[:, -1]
            hit_eos |= tf.equal(next_id, Offsets.EOS)
            next_id = tf.reshape(next_id, [B, 1])

            decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
            return i + 1, hit_eos, decoded_ids
Example #3
    def decode(self,
               encoder_outputs,
               src_len,
               tgt_len,
               pdrop,
               layers=1,
               scope='TransformerDecoder',
               num_heads=4,
               scale=True,
               activation_type='relu',
               d_ff=None,
               **kwargs):
        """self.best is [T, B]"""
        src_enc = encoder_outputs.output
        if hasattr(encoder_outputs, 'src_mask'):
            src_mask = encoder_outputs.src_mask
        else:
            T = get_shape_as_list(src_enc)[1]
            src_mask = tf.sequence_mask(src_len, T, dtype=tf.float32)
        tgt_embed = self.tgt_embedding.encode(kwargs.get('tgt'))
        T = get_shape_as_list(tgt_embed)[1]
        tgt_mask = subsequent_mask(T)
        h = transformer_decoder_stack(tgt_embed, src_enc, src_mask, tgt_mask,
                                      num_heads, pdrop, scale, layers,
                                      activation_type, scope, d_ff)

        vsz = self.tgt_embedding.vsz
        do_weight_tying = bool(kwargs.get('tie_weights', True))
        hsz = get_shape_as_list(h)[-1]
        h = tf.reshape(h, [-1, hsz])
        if do_weight_tying and hsz == self.tgt_embedding.get_dsz():
            with tf.variable_scope(self.tgt_embedding.scope, reuse=True):
                W = tf.get_variable("W")
                outputs = tf.matmul(h, W, transpose_b=True, name="logits")
        else:
            vocab_w = tf.get_variable("vocab_w", [hsz, vsz], dtype=tf.float32)
            vocab_b = tf.get_variable("vocab_b", [vsz], dtype=tf.float32)
            outputs = tf.nn.xw_plus_b(h, vocab_w, vocab_b, name="logits")
        self.preds = tf.transpose(tf.reshape(outputs, [-1, T, vsz]), [1, 0, 2])
        best = tf.argmax(self.preds, -1)
        self.output(best)
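All four examples rely on subsequent_mask, which is not shown. A minimal sketch of the usual causal-mask construction (the [1, 1, T, T] broadcast shape is an assumption about this library's convention):

    def subsequent_mask(T):
        # Lower-triangular mask: position t may attend only to positions
        # <= t; the two singleton axes broadcast over batch and heads.
        mask = tf.linalg.band_part(tf.ones([T, T]), -1, 0)
        return tf.reshape(mask, [1, 1, T, T])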
Example #4
    def decode(self,
               encoder_outputs,
               src_len,
               tgt_len,
               pdrop,
               layers=1,
               scope='TransformerDecoder',
               num_heads=4,
               scale=True,
               activation_type='relu',
               d_ff=None,
               **kwargs):
        """self.best is [T, B]"""
        src_enc = encoder_outputs.output
        if hasattr(encoder_outputs, 'src_mask'):
            src_mask = encoder_outputs.src_mask
        else:
            T = get_shape_as_list(src_enc)[1]
            src_mask = tf.sequence_mask(src_len, T, dtype=tf.float32)
        tgt_embed = self.tgt_embedding.encode(kwargs.get('tgt'))
        T = get_shape_as_list(tgt_embed)[1]
        tgt_mask = subsequent_mask(T)
        h = transformer_decoder_stack(src_enc, tgt_embed, src_mask, tgt_mask,
                                      num_heads, pdrop, scale, layers,
                                      activation_type, scope, d_ff)

        vsz = self.tgt_embedding.vsz
        do_weight_tying = bool(kwargs.get('tie_weights', True))
        hsz = get_shape_as_list(h)[-1]
        h = tf.reshape(h, [-1, hsz])
        if do_weight_tying and hsz == self.tgt_embedding.get_dsz():
            with tf.variable_scope(self.tgt_embedding.scope, reuse=True):
                W = tf.get_variable("W")
                outputs = tf.matmul(h, W, transpose_b=True, name="logits")
        else:
            vocab_w = tf.get_variable("vocab_w", [hsz, vsz], dtype=tf.float32)
            vocab_b = tf.get_variable("vocab_b", [vsz], dtype=tf.float32)
            outputs = tf.nn.xw_plus_b(h, vocab_w, vocab_b, name="logits")
        self.preds = tf.transpose(tf.reshape(outputs, [-1, T, vsz]), [1, 0, 2])
        best = tf.argmax(self.preds, -1)
        self.output(best)
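The tie_weights branch in each example reuses the target embedding table as the output projection, which is why it is gated on hsz == get_dsz(): logits = h @ W^T only type-checks when the decoder hidden size equals the embedding dimension. A stripped-down illustration (all names here are illustrative, not from the library):

    N, V, D = 128, 32000, 512                   # positions, vocab, embed dim
    h = tf.random_normal([N, D])                # decoder states, flattened to 2D
    W = tf.get_variable("tied_W", [V, D])       # embedding table, [vsz, dsz]
    # With tying, the output projection is just the transposed table.
    logits = tf.matmul(h, W, transpose_b=True)  # [N, V]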