def model(hparams, X, past=None, scope='model', reuse=False):
    with tf.variable_scope(scope, reuse=reuse):
        results = {}
        batch, sequence = shape_list(X)

        wpe = tf.get_variable(
            'wpe', [hparams.n_ctx, hparams.n_embd],
            initializer=tf.random_normal_initializer(stddev=0.01))
        wte = tf.get_variable(
            'wte', [hparams.n_vocab, hparams.n_embd],
            initializer=tf.random_normal_initializer(stddev=0.02))
        past_length = 0 if past is None else tf.shape(past)[-2]
        h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))

        # Transformer
        presents = []
        pasts = tf.unstack(past, axis=1) if past is not None \
            else [None] * hparams.n_layer
        assert len(pasts) == hparams.n_layer
        for layer, past in enumerate(pasts):
            h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
            presents.append(present)
        presents = tf.stack(presents, axis=1)
        presents.set_shape(past_shape(hparams=hparams, batch_size=None))
        results['presents'] = presents
        h = norm(h, 'ln_f')

        # Language model loss.  Do tokens <n predict token n?
        h_flat = tf.reshape(h, [batch * sequence, hparams.n_embd])
        logits = tf.matmul(h_flat, wte, transpose_b=True)
        logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
        results['logits'] = logits
        return results
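
# Minimal usage sketch (hypothetical helper, not part of the original code) for the
# plain model() above, assuming it lives alongside the helpers it calls (shape_list,
# positions_for, block, norm, past_shape) under TF 1.x. The HParams values are
# illustrative, not the project's actual configuration.
def _example_forward_pass():
    import tensorflow as tf
    from tensorflow.contrib.training import HParams

    hparams = HParams(n_vocab=50257, n_ctx=1024, n_embd=768, n_head=12, n_layer=12)
    tokens = tf.placeholder(tf.int32, [None, None], name='tokens')

    # First call builds the variables; 'presents' caches the per-layer keys/values.
    out = model(hparams, tokens)
    # A second call can continue decoding from the cached state, reusing the variables.
    next_out = model(hparams, tokens[:, -1:], past=out['presents'], reuse=True)
    return out['logits'], next_out['logits']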
def model(*, hparams, X, src_seq_mask, scope='model', reuse=False):
    with tf.variable_scope(scope, reuse=reuse):
        results = {}
        batch, sequence = shape_list(X)

        wpe = tf.get_variable(
            'wpe', [hparams.n_ctx, hparams.n_embd],
            initializer=tf.random_normal_initializer(stddev=0.01))
        wte = tf.get_variable(
            'wte', [hparams.n_vocab, hparams.n_embd],
            initializer=tf.random_normal_initializer(stddev=0.02))
        past_length = 0
        h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))

        # Transformer
        presents = []
        for layer in range(hparams.n_layer):
            h, present = block(h, 'h%d' % layer, past=None, hparams=hparams,
                               src_seq_mask=src_seq_mask)
            presents.append(present)
        results['present'] = tf.stack(presents, axis=1)
        h = norm(h, 'ln_f')

        # Language model loss.  Do tokens <n predict token n?
        h_flat = tf.reshape(h, [batch * sequence, hparams.n_embd])
        logits = tf.matmul(h_flat, wte, transpose_b=True)
        logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
        results['logits'] = logits
        return results
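
# Hedged sketch (hypothetical helper) for the masked model() variant directly above;
# the two model() definitions presumably live in separate modules. The shape block()
# expects for src_seq_mask is not shown here, so a [batch, max_src_len] float mask
# built from per-example source lengths is assumed purely for illustration.
def _example_masked_call(hparams, max_src_len=64):
    import tensorflow as tf
    tokens = tf.placeholder(tf.int32, [None, None], name='tokens')
    src_lengths = tf.placeholder(tf.int32, [None], name='src_lengths')
    src_seq_mask = tf.sequence_mask(src_lengths, maxlen=max_src_len, dtype=tf.float32)
    out = model(hparams=hparams, X=tokens, src_seq_mask=src_seq_mask)
    return out['logits'], out['present']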
def decode_all(self, tokens, past_list, enc_h_list):
    """Decode against multiple encoded sources (GPT-HA style).

    If len(past_list) == 1 this reduces to a plain GPT encoder-decoder model.
    """
    with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
        with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
            results = {}
            if not isinstance(past_list, list):
                past_list = [past_list]
            batch, sequence = shape_list(tokens)
            # Position offset: the (maximum) cached length across all source pasts.
            all_past_length = [
                0 if past is None else tf.shape(past)[-2] for past in past_list
            ]
            past_length = tf.reduce_max(tf.stack(all_past_length, axis=0), axis=0)
            h = tf.gather(self.wte, tokens) + tf.gather(
                self.wpe, positions_for(tokens, past_length))

            values_present = {}
            for i in range(0, self.hparams.n_layer):
                querys = h
                values_h = []
                # Run layer i once per source, each with that source's cached past.
                for j in range(0, len(past_list)):
                    past = past_list[j]
                    pasts = tf.unstack(past, axis=1) if past is not None \
                        else [None] * self.hparams.n_layer
                    assert len(pasts) == self.hparams.n_layer
                    h, present = block(querys, 'h%d' % i, past=pasts[i],
                                       hparams=self.hparams)
                    values_h.append(h)
                    if j in values_present:
                        values_present[j].append(present)
                    else:
                        values_present[j] = [present]
                # Source-level attention: score each source against the queries and
                # fuse the per-source outputs into a single hidden state.
                enc_h_all = tf.concat(enc_h_list, axis=1)
                attn_score = tf.tensordot(querys, self.attn_w, axes=(2, 0))
                attn_score = tf.matmul(
                    attn_score,
                    tf.transpose(enc_h_all, perm=(0, 2, 1)))  # batch * seq * context_num
                attn_score = tf.nn.softmax(attn_score, axis=2)
                val_h_cat = tf.stack(values_h, axis=2)
                val_h_cat = tf.expand_dims(attn_score, axis=3) * val_h_cat
                val_h_cat = tf.reduce_sum(val_h_cat, axis=2)
                h = val_h_cat

            # Append the new keys/values to each source's cache.
            for j in range(0, len(past_list)):
                values_present[j] = tf.stack(values_present[j], axis=1)
                past_list[j] = tf.concat([past_list[j], values_present[j]], axis=-2)

            h = norm(h, 'ln_f')

            # Language model loss.  Do tokens <n predict token n?
            h_flat = tf.reshape(h, [batch * sequence, self.hparams.n_embd])
            logits = tf.matmul(h_flat, self.wte, transpose_b=True)
            logits = tf.reshape(logits, [batch, sequence, self.hparams.n_vocab])
            results['logits'] = logits
            return results
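
# Hedged wiring sketch for decode_all(), written as a hypothetical method of the same
# (unnamed) class. It assumes the encoder below (encode_which_outputs_all_layer_h)
# supplies each source's key/value cache, and that a [batch, 1, n_embd] last-token
# summary per source is what self.attn_w scores against; the project's actual summary
# may differ.
def _example_decode_all(self, src_token_list, src_len_list, tgt_tokens):
    past_list, enc_h_list = [], []
    for src_tokens, src_len in zip(src_token_list, src_len_list):
        enc_out, all_h = self.encode_which_outputs_all_layer_h(src_tokens, src_len)
        past_list.append(enc_out)
        # Use the final-layer last-token state as this source's summary vector.
        enc_h_list.append(tf.expand_dims(all_h[-1], axis=1))
    return self.decode_all(tgt_tokens, past_list, enc_h_list)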
def decode_one_step(self, hparams: "unused; kept for API consistency",
                    input_token, past_dec: list):
    with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
        with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
            all_past_length = [
                0 if past_dec[j] is None else tf.shape(past_dec[j])[-2]
                for j in range(0, len(past_dec))
            ]
            past_length = tf.reduce_max(tf.stack(all_past_length, axis=0), axis=0)
            h = tf.gather(self.wte, input_token) + tf.gather(
                self.wpe, positions_for(input_token, past_length))
            results = {}
            batch, sequence = shape_list(input_token)

            values_present = {}
            for i in range(0, self.hparams.n_layer):
                querys = h
                values_h = []
                for j in range(0, len(past_dec)):
                    dec_pasts = tf.unstack(past_dec[j], axis=1) \
                        if past_dec[j] is not None \
                        else [None] * self.hparams.n_layer
                    h, present = block(querys, 'h%d' % i, past=dec_pasts[i],
                                       hparams=self.hparams)
                    values_h.append(h)
                    if j in values_present:
                        values_present[j].append(present)
                    else:
                        values_present[j] = [present]
                # Fuse the per-source outputs with attention over the encoder summaries.
                attn_score = tf.tensordot(querys, self.attn_w, axes=(2, 0))
                attn_score = tf.matmul(
                    attn_score,
                    tf.transpose(self.enc_h_all, perm=(0, 2, 1)))  # batch * seq * context_num
                attn_score = tf.nn.softmax(attn_score, axis=2)
                val_h_cat = tf.stack(values_h, axis=2)
                val_h_cat = tf.expand_dims(attn_score, axis=3) * val_h_cat
                val_h_cat = tf.reduce_sum(val_h_cat, axis=2)
                h = val_h_cat

            # Append this step's keys/values to each cache.
            for j in range(0, len(past_dec)):
                values_present[j] = tf.stack(values_present[j], axis=1)
                past_dec[j] = tf.concat([past_dec[j], values_present[j]], axis=-2)

            h = norm(h, 'ln_f')

            # Language model loss.  Do tokens <n predict token n?
            h_flat = tf.reshape(h, [batch * sequence, self.hparams.n_embd])
            logits = tf.matmul(h_flat, self.wte, transpose_b=True)
            logits = tf.reshape(logits, [batch, sequence, self.hparams.n_vocab])
            results['logits'] = logits
            results['presents'] = past_dec
            return results
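
# Hedged sketch of greedy incremental decoding with decode_one_step(), again written as
# a hypothetical method of the same class. It assumes `past_dec` already holds the
# per-source caches (e.g. from the encoder below) and unrolls a fixed number of
# graph-mode steps for clarity; a production decoder would more likely use
# tf.while_loop or beam search.
def _example_greedy_decode(self, first_token, past_dec, num_steps=8):
    token = first_token  # [batch, 1] int32 start token
    outputs = []
    for _ in range(num_steps):
        res = self.decode_one_step(None, token, past_dec)
        token = tf.argmax(res['logits'][:, -1:, :], axis=-1, output_type=tf.int32)
        past_dec = res['presents']  # each cache grows by one position per step
        outputs.append(token)
    return tf.concat(outputs, axis=1)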
def encode_which_outputs_all_layer_h(self, X, h_len, past=None,
                                     scope='encoder', reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=reuse):
        with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
            # Transformer
            wpe = tf.get_variable(
                'wpe', [self.hparams.n_ctx, self.hparams.n_embd],
                initializer=tf.random_normal_initializer(stddev=0.01))
            wte = tf.get_variable(
                'wte', [self.hparams.n_vocab, self.hparams.n_embd],
                initializer=tf.random_normal_initializer(stddev=0.02))
            past_length = 0 if past is None else tf.shape(past)[-2]
            h = tf.gather(wte, X, name='gggggg1') + tf.gather(
                wpe, positions_for(X, past_length), name='ggggggg2')

            presents = []
            pasts = tf.unstack(past, axis=1) if past is not None \
                else [None] * self.hparams.n_layer
            assert len(pasts) == self.hparams.n_layer
            all_h = []
            final_id = h_len - 1
            for layer, past_one in enumerate(pasts):
                h, present = block(h, 'h%d' % layer, past=past_one,
                                   hparams=self.hparams)
                presents.append(present)
                # Collect each layer's hidden state at the last real token.
                all_h.append(
                    gather_2d(h, tf.expand_dims(final_id, axis=1))[:, 0, :])
            presents = tf.stack(presents, axis=1)
            h = norm(h, 'ln_f')
            all_h.append(
                gather_2d(h, tf.expand_dims(final_id, axis=1))[:, 0, :])

            target_mask = tf.sequence_mask(
                h_len, maxlen=tf.shape(h)[1], dtype=tf.float32)
            # If h_len - 1 were used here, the sentence(-end) token would be masked out too.
            target_mask = tf.expand_dims(target_mask, 2)

            # Zero the cached keys/values beyond each sequence's true length:
            # move the sequence axis next to batch, apply the mask, then restore the layout.
            encode_out = tf.transpose(presents, perm=(0, 4, 2, 3, 1, 5))
            ori_enc_shape = tf.shape(encode_out)
            encode_out = tf.reshape(
                encode_out,
                shape=(tf.shape(presents)[0], tf.shape(presents)[4], -1))
            encode_out = tf.multiply(encode_out, target_mask)
            encode_out = tf.reshape(encode_out, shape=ori_enc_shape)
            encode_out = tf.transpose(encode_out, perm=(0, 4, 2, 3, 1, 5))
            encode_out.set_shape(
                past_shape(hparams=self.hparams, batch_size=None))
            return encode_out, all_h
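
# Hedged shape sketch (hypothetical method) for the encoder output above, assuming
# past_shape() follows the standard GPT-2 layout
# [batch, n_layer, 2, n_head, sequence, n_embd // n_head]. The asserts only document
# the expected structure: a length-masked key/value cache plus one last-token summary
# per layer and one more after the final layer norm.
def _example_encoder_shapes(self, src_tokens, src_len):
    encode_out, all_h = self.encode_which_outputs_all_layer_h(src_tokens, src_len)
    static = encode_out.shape.as_list()
    assert len(static) == 6 and static[2] == 2   # [.., 2 (keys/values), ..]
    assert static[1] == self.hparams.n_layer
    assert len(all_h) == self.hparams.n_layer + 1
    return encode_out, all_h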