Example #1
# NOTE: TF 1.x graph-mode code; on TF 2.x use
#   import tensorflow.compat.v1 as tf; tf.disable_v2_behavior()
import tensorflow as tf

# shape_list, positions_for, block, norm and past_shape are the standard helpers
# from GPT-2's model.py; gather_2d is an additional helper. All of them are
# assumed to be defined elsewhere in this file.


def model(hparams, X, past=None, scope='model', reuse=False):
    with tf.variable_scope(scope, reuse=reuse):
        results = {}
        batch, sequence = shape_list(X)
        wpe = tf.get_variable(
            'wpe', [hparams.n_ctx, hparams.n_embd],
            initializer=tf.random_normal_initializer(stddev=0.01))
        wte = tf.get_variable(
            'wte', [hparams.n_vocab, hparams.n_embd],
            initializer=tf.random_normal_initializer(stddev=0.02))
        past_length = 0 if past is None else tf.shape(past)[-2]
        h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))
        # Transformer
        presents = []
        pasts = tf.unstack(
            past, axis=1) if past is not None else [None] * hparams.n_layer
        assert len(pasts) == hparams.n_layer
        for layer, past in enumerate(pasts):
            h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
            presents.append(present)
        presents = tf.stack(presents, axis=1)
        presents.set_shape(past_shape(hparams=hparams, batch_size=None))
        results['presents'] = presents
        h = norm(h, 'ln_f')
        # Language model loss.  Do tokens <n predict token n?
        h_flat = tf.reshape(h, [batch * sequence, hparams.n_embd])
        logits = tf.matmul(h_flat, wte, transpose_b=True)
        logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
        results['logits'] = logits
        return results
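
As a quick sanity check, here is a minimal usage sketch for this first model variant. Everything below is illustrative: the hparams values are made-up small numbers (not a released GPT-2 config), and the GPT-2 helpers mentioned above must already be in scope.

# Hypothetical usage sketch; hparams values are toy numbers.
import numpy as np
from types import SimpleNamespace

hparams = SimpleNamespace(n_vocab=5000, n_ctx=128, n_embd=64, n_head=4, n_layer=2)
X = tf.placeholder(tf.int32, [None, None], name='tokens')
out = model(hparams, X)        # out['logits'], out['presents']

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    logits = sess.run(out['logits'],
                      feed_dict={X: np.random.randint(0, 5000, size=(2, 16)).astype(np.int32)})
    print(logits.shape)        # (2, 16, 5000)
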
# Variant of model(): incremental decoding via `past` is dropped (positions always
# start at 0) and a source-sequence mask is threaded through to every block().
def model(*, hparams, X, src_seq_mask, scope='model', reuse=False):
    with tf.variable_scope(scope, reuse=reuse):
        results = {}
        batch, sequence = shape_list(X)
        wpe = tf.get_variable(
            'wpe', [hparams.n_ctx, hparams.n_embd],
            initializer=tf.random_normal_initializer(stddev=0.01))
        wte = tf.get_variable(
            'wte', [hparams.n_vocab, hparams.n_embd],
            initializer=tf.random_normal_initializer(stddev=0.02))
        past_length = 0
        h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))
        # Transformer
        presents = []
        for layer in range(hparams.n_layer):
            h, present = block(h,
                               'h%d' % layer,
                               past=None,
                               hparams=hparams,
                               src_seq_mask=src_seq_mask)
            presents.append(present)
        results['present'] = tf.stack(presents, axis=1)  # note: key is 'present' here, not 'presents'
        h = norm(h, 'ln_f')
        # Language model loss.  Do tokens <n predict token n?
        h_flat = tf.reshape(h, [batch * sequence, hparams.n_embd])
        logits = tf.matmul(h_flat, wte, transpose_b=True)
        logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
        results['logits'] = logits
        return results
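
This variant expects the caller to supply src_seq_mask. One plausible way to build such a mask from per-example source lengths is sketched below; the mask construction and how block() consumes it are not shown in this excerpt, so treat the names as assumptions (hparams is the toy config from the sketch above).

# Hypothetical construction of src_seq_mask from per-example source lengths.
src_tokens = tf.placeholder(tf.int32, [None, None], name='src_tokens')
src_len = tf.placeholder(tf.int32, [None], name='src_len')
src_seq_mask = tf.sequence_mask(src_len,
                                maxlen=tf.shape(src_tokens)[1],
                                dtype=tf.float32)    # [batch, src_seq]
out = model(hparams=hparams, X=src_tokens, src_seq_mask=src_seq_mask)
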
 # The methods below belong to an encoder-decoder wrapper class whose definition
 # is not shown here; `self` is expected to provide wte, wpe, attn_w, enc_h_all,
 # hparams and scope.
 def decode_all(self, tokens, past_list, enc_h_list):
     """Decode against multiple source contexts (GPT-HA style); if
     len(past_list) == 1 this reduces to a plain GPT encoder-decoder model."""
     with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
         with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
             results = {}
             if type(past_list) != list:
                 past_list = [past_list]
             batch, sequence = shape_list(tokens)
             # All cached pasts are assumed to share the same length, so the
             # position offset for the new tokens is taken from the first one.
             all_past_length = [
                 0 if past_list[0] is None else tf.shape(past_list[0])[-2]
             ]
             past_length = tf.reduce_max(tf.stack(all_past_length, axis=0),
                                         axis=0)
             h = tf.gather(self.wte, tokens) + tf.gather(
                 self.wpe, positions_for(tokens, past_length))
             values_present = {}
             # For each layer, run the transformer block once per source past and
             # fuse the per-source outputs with an attention over the encoder states.
             for i in range(0, self.hparams.n_layer):
                 querys = h
                 values_h = []
                 for j in range(0, len(past_list)):
                     past = past_list[j]
                     pasts = (tf.unstack(past, axis=1)
                              if past is not None else
                              [None] * self.hparams.n_layer)
                     assert len(pasts) == self.hparams.n_layer
                     h, present = block(querys,
                                        'h%d' % i,
                                        past=pasts[i],
                                        hparams=self.hparams)
                     values_h.append(h)
                     if j in values_present:
                         values_present[j].append(present)
                     else:
                         values_present[j] = [present]
                 # Bilinear attention over the concatenated per-source encoder
                 # summaries; scores end up with shape [batch, seq, context_num].
                 enc_h_all = tf.concat(enc_h_list, axis=1)
                 attn_score = tf.tensordot(querys, self.attn_w, axes=(2, 0))
                 attn_score = tf.matmul(
                     attn_score,
                     tf.transpose(enc_h_all,
                                  perm=(0, 2, 1)))  # batch*seq*context_num
                 attn_score = tf.nn.softmax(attn_score, axis=2)
                 val_h_cat = tf.stack(values_h, axis=2)
                 val_h_cat = tf.expand_dims(attn_score, axis=3) * val_h_cat
                 val_h_cat = tf.reduce_sum(val_h_cat, axis=2)
                 h = val_h_cat
             for j in range(0, len(past_list)):
                 values_present[j] = tf.stack(values_present[j], axis=1)
                 past_list[j] = tf.concat([past_list[j], values_present[j]],
                                          axis=-2)
             h = norm(h, 'ln_f')
             # Language model loss.  Do tokens <n predict token n?
             h_flat = tf.reshape(h, [batch * sequence, self.hparams.n_embd])
             logits = tf.matmul(h_flat, self.wte, transpose_b=True)
             logits = tf.reshape(logits,
                                 [batch, sequence, self.hparams.n_vocab])
             results['logits'] = logits
             return results
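
decode_all relies on several members that the wrapper class must provide but that are not shown in this excerpt. The skeleton below is a hypothetical reconstruction, purely to document the assumed attributes; the class name, scope and initializer values are guesses.

# Hypothetical wrapper-class skeleton; only the attributes read by the methods
# in this example are defined. Names, scopes and stddev values are assumptions.
class GPTHADecoder:
    def __init__(self, hparams, scope='decoder'):
        self.hparams = hparams
        self.scope = scope
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
                self.wpe = tf.get_variable(
                    'wpe', [hparams.n_ctx, hparams.n_embd],
                    initializer=tf.random_normal_initializer(stddev=0.01))
                self.wte = tf.get_variable(
                    'wte', [hparams.n_vocab, hparams.n_embd],
                    initializer=tf.random_normal_initializer(stddev=0.02))
                # Bilinear weight scoring decoder states against encoder summaries.
                self.attn_w = tf.get_variable(
                    'attn_w', [hparams.n_embd, hparams.n_embd],
                    initializer=tf.random_normal_initializer(stddev=0.02))
        self.enc_h_all = None  # set after encoding; read by decode_one_step

    # decode_all, decode_one_step and encode_which_outputs_all_layer_h
    # (shown in this example) would be methods of this class.
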
 def decode_one_step(self, hparams: "unused; kept only for API consistency",
                     input_token, past_dec: list):
     with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
         with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
             all_past_length = [
                 0 if past_dec[j] is None else tf.shape(past_dec[j])[-2]
                 for j in range(0, len(past_dec))
             ]
             past_length = tf.reduce_max(tf.stack(all_past_length, axis=0),
                                         axis=0)
             h = tf.gather(self.wte, input_token) + tf.gather(
                 self.wpe, positions_for(input_token, past_length))
             results = {}
             batch, sequence = shape_list(input_token)
             values_present = {}
             for i in range(0, self.hparams.n_layer):
                 querys = h
                 values_h = []
                 for j in range(0, len(past_dec)):
                     dec_pasts = (tf.unstack(past_dec[j], axis=1)
                                  if past_dec[j] is not None else
                                  [None] * self.hparams.n_layer)
                     h, present = block(querys,
                                        'h%d' % i,
                                        past=dec_pasts[i],
                                        hparams=self.hparams)
                     values_h.append(h)
                     if j in values_present:
                         values_present[j].append(present)
                     else:
                         values_present[j] = [present]
                 attn_score = tf.tensordot(querys, self.attn_w, axes=(2, 0))
                 attn_score = tf.matmul(
                     attn_score,
                     tf.transpose(self.enc_h_all,
                                  perm=(0, 2, 1)))  # batch*seq*context_num
                 attn_score = tf.nn.softmax(attn_score, axis=2)
                 val_h_cat = tf.stack(values_h, axis=2)
                 val_h_cat = tf.expand_dims(attn_score, axis=3) * val_h_cat
                 val_h_cat = tf.reduce_sum(val_h_cat, axis=2)
                 h = val_h_cat
             for j in range(0, len(past_dec)):
                 values_present[j] = tf.stack(values_present[j], axis=1)
                 past_dec[j] = tf.concat([past_dec[j], values_present[j]],
                                         axis=-2)
             h = norm(h, 'ln_f')
             # Language model loss.  Do tokens <n predict token n?
             h_flat = tf.reshape(h, [batch * sequence, self.hparams.n_embd])
             logits = tf.matmul(h_flat, self.wte, transpose_b=True)
             logits = tf.reshape(logits,
                                 [batch, sequence, self.hparams.n_vocab])
             results['logits'] = logits
             results['presents'] = past_dec
             return results
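
A single greedy decoding step could be driven from this method roughly as sketched below. dec is an instance of the hypothetical wrapper class above, pasts is a list with one key/value cache per source (e.g. the encoder presents), and dec.enc_h_all must already hold the concatenated encoder summaries; all of these names are illustrative.

# One illustrative greedy step; `dec` and `pasts` are assumed to exist already,
# and dec.enc_h_all must have been populated from the encoder.
prev_token = tf.placeholder(tf.int32, [None, 1], name='prev_token')
step = dec.decode_one_step(None, prev_token, past_dec=pasts)  # hparams arg is unused
next_token = tf.argmax(step['logits'][:, -1, :], axis=-1,
                       output_type=tf.int32)       # greedy pick, shape [batch]
pasts = step['presents']                           # caches grown along axis -2
next_input = tf.expand_dims(next_token, axis=1)    # feed back as [batch, 1]
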
 def encode_which_outputs_all_layer_h(self,
                                      X,
                                      h_len,
                                      past=None,
                                      scope='encoder',
                                      reuse=tf.AUTO_REUSE):
     with tf.variable_scope(scope, reuse=reuse):
         with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
             # Transformer
             wpe = tf.get_variable(
                 'wpe', [self.hparams.n_ctx, self.hparams.n_embd],
                 initializer=tf.random_normal_initializer(stddev=0.01))
             wte = tf.get_variable(
                 'wte', [self.hparams.n_vocab, self.hparams.n_embd],
                 initializer=tf.random_normal_initializer(stddev=0.02))
             past_length = 0 if past is None else tf.shape(past)[-2]
             h = tf.gather(wte, X) + tf.gather(
                 wpe, positions_for(X, past_length))
             presents = []
             pasts = tf.unstack(
                 past, axis=1
             ) if past is not None else [None] * self.hparams.n_layer
             assert len(pasts) == self.hparams.n_layer
             all_h = []
             # Index of the last real (non-padding) token in each sequence; the
             # hidden state at this position is collected from every layer.
             final_id = h_len - 1
             for layer, past_one in enumerate(pasts):
                 h, present = block(h,
                                    'h%d' % layer,
                                    past=past_one,
                                    hparams=self.hparams)
                 presents.append(present)
                 all_h.append(
                     gather_2d(h, tf.expand_dims(final_id, axis=1))[:,
                                                                    0, :])
             presents = tf.stack(presents, axis=1)
             h = norm(h, 'ln_f')
             all_h.append(
                 gather_2d(h, tf.expand_dims(final_id, axis=1))[:, 0, :])
             target_mask = tf.sequence_mask(
                 h_len, maxlen=tf.shape(h)[1],
                 dtype=tf.float32)  # with h_len-1 the final sentence token would be masked out too
             target_mask = tf.expand_dims(target_mask, 2)
             # Swap the layer and sequence axes of `presents` so the padding mask
             # can be broadcast over the flattened features, then swap back below.
             encode_out = tf.transpose(presents, perm=(0, 4, 2, 3, 1, 5))
             ori_enc_shape = tf.shape(encode_out)
             encode_out = tf.reshape(encode_out,
                                     shape=(tf.shape(presents)[0],
                                            tf.shape(presents)[4], -1))
             encode_out = tf.multiply(encode_out, target_mask)
             encode_out = tf.reshape(encode_out, shape=ori_enc_shape)
             encode_out = tf.transpose(encode_out, perm=(0, 4, 2, 3, 1, 5))
             encode_out.set_shape(
                 past_shape(hparams=self.hparams, batch_size=None))
             return encode_out, all_h
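
Putting the pieces together, an end-to-end wiring sketch for two sources might look like the following. dec is again the hypothetical wrapper instance; the choice of each source's last-layer final-token state as its summary vector for the fusion attention is an assumption, since that pooling step is not shown in this excerpt.

# Hypothetical end-to-end wiring for two sources (training-style, teacher forcing).
src_a = tf.placeholder(tf.int32, [None, None], name='src_a')
src_b = tf.placeholder(tf.int32, [None, None], name='src_b')
len_a = tf.placeholder(tf.int32, [None], name='len_a')
len_b = tf.placeholder(tf.int32, [None], name='len_b')
tgt = tf.placeholder(tf.int32, [None, None], name='tgt')

# Shared encoder (same 'encoder' scope, AUTO_REUSE) applied to both sources.
past_a, all_h_a = dec.encode_which_outputs_all_layer_h(src_a, len_a)
past_b, all_h_b = dec.encode_which_outputs_all_layer_h(src_b, len_b)

# One summary vector per source: last layer's final-token state, [batch, 1, n_embd].
enc_h_list = [tf.expand_dims(all_h_a[-1], 1), tf.expand_dims(all_h_b[-1], 1)]
dec.enc_h_all = tf.concat(enc_h_list, axis=1)     # read by decode_one_step

train_out = dec.decode_all(tgt, [past_a, past_b], enc_h_list)
logits = train_out['logits']                      # [batch, tgt_len, n_vocab]
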