def call(self, input_ids, past=None, training=False):
    """Run the decoder stack and return logits plus the cached key/values."""
    results = {}
    input_ids = tf.cast(input_ids, tf.int32)
    batch_size = get_shape_list(input_ids)[0]
    # `past` is stacked as [B, num_layers, 2, N, T_past, H], so the cached
    # sequence length sits on the second-to-last axis.
    past_length = 0 if past is None else get_shape_list(past)[-2]

    inputs_embeds = self.token_embedding(input_ids)
    position_embeds = self.posembedding(
        self.position_ids(input_ids, past_length))
    hidden_states = inputs_embeds + position_embeds
    hidden_states = self.embedding_drop(hidden_states, training=training)

    presents = []
    pasts = tf.unstack(
        past, axis=1) if past is not None else [None] * self.num_hidden_layers
    for i, (block, layer_past) in enumerate(zip(self.encoder_layers, pasts)):
        hidden_states, present = block(hidden_states, layer_past, training)
        presents.append(present)
    results['presents'] = tf.stack(presents, axis=1)

    output = self.ln_f(hidden_states)
    # The output projection is tied to the input embedding table.
    output_flat = tf.reshape(output, [-1, self.hidden_size])
    logits = tf.matmul(output_flat,
                       self.token_embedding.embedding_table,
                       transpose_b=True)
    logits = tf.reshape(logits, [batch_size, -1, self.vocab_size])
    results['logits'] = logits
    return results
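# Illustrative sketch (not part of the model) of how the `presents` cache above
# is laid out. Per layer, keys/values are stacked as [B, 2, N, T, H]; stacking
# layers on axis 1 gives [B, num_layers, 2, N, T, H], which is why
# `get_shape_list(past)[-2]` recovers the cached sequence length. All shapes
# below are dummy values chosen for the demo.
import tensorflow as tf

B, N, T, H, num_layers = 2, 4, 5, 8, 3
k = tf.random.normal([B, N, T, H])
v = tf.random.normal([B, N, T, H])
present = tf.stack([k, v], axis=1)                # [B, 2, N, T, H]
past = tf.stack([present] * num_layers, axis=1)   # [B, num_layers, 2, N, T, H]
assert past.shape[-2] == T                        # past_length seen by `call`
layer_pasts = tf.unstack(past, axis=1)            # one slice per layer
past_k, past_v = tf.unstack(layer_pasts[0], axis=1)
assert past_k.shape == (B, N, T, H)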
def call(self, from_tensor, layer_past=None, is_training=True):
    """Multi-head self-attention with a causal mask and key/value caching.

    :param from_tensor: [B, T, H]
    :param layer_past: cached [B, 2, N, T_past, H] keys/values from earlier steps
    :param is_training: whether dropout is active
    :return: (attention output [B*T, N*H], present keys/values [B, 2, N, T, H])
    """
    from_shape = get_shape_list(from_tensor, expected_rank=[3])
    self.batch_size = from_shape[0]
    self.from_seq_length = from_shape[1]

    from_tensor = reshape_to_matrix(from_tensor)  # [B*T, H]
    output = self.c_att(from_tensor)              # [B*T, 3*N*H]
    q, k, v = tf.split(output, 3, axis=1)

    # [B, N, T, H]
    q = transpose_for_scores(q, self.batch_size, self.num_attention_heads,
                             self.from_seq_length, self.size_per_head)
    k = transpose_for_scores(k, self.batch_size, self.num_attention_heads,
                             self.from_seq_length, self.size_per_head)
    v = transpose_for_scores(v, self.batch_size, self.num_attention_heads,
                             self.from_seq_length, self.size_per_head)

    present = tf.stack([k, v], axis=1)
    if layer_past is not None:
        past_key, past_value = tf.unstack(layer_past, axis=1)
        k = tf.concat([past_key, k], axis=-2)
        v = tf.concat([past_value, v], axis=-2)

    # attention scores: [B, N, T, T_total]
    distance = tf.linalg.matmul(q, k, transpose_b=True)
    if self.scale:
        distance = distance * tf.math.rsqrt(float(get_shape_list(v)[-1]))

    _, _, from_length, to_length = get_shape_list(distance)
    distance_b = self.causal_attention_mask(from_length, to_length,
                                            dtype=distance.dtype)
    distance_b = tf.reshape(distance_b, [1, 1, from_length, to_length])
    distance = distance * distance_b - 1e10 * (1 - distance_b)

    attention_probs = self.softmax(distance, axis=-1)
    attention_probs = self.drop_out(attention_probs, training=is_training)

    context_layer = tf.linalg.matmul(attention_probs, v)
    context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
    c_shape = get_shape_list(context_layer)
    # [B, T, N, H] -> [B*T, N*H]
    context_layer = tf.reshape(
        context_layer, [c_shape[0] * c_shape[1]] + [c_shape[-2] * c_shape[-1]])

    output = self.c_proj(context_layer)
    output = self.resid_out(output, training=is_training)
    return output, present
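# Hedged sketch of what `causal_attention_mask` is assumed to produce: a band
# of ones so that query position i can only attend to key positions at or
# before it (including cached past positions). This mirrors the standard GPT-2
# mask; the actual helper in this module may differ in detail.
import tensorflow as tf

def causal_attention_mask_sketch(from_length, to_length, dtype=tf.float32):
    # 1.0 where key position j is visible to query position i, else 0.0.
    i = tf.range(from_length)[:, None]
    j = tf.range(to_length)
    mask = i >= j - to_length + from_length
    return tf.cast(mask, dtype)

print(causal_attention_mask_sketch(3, 3).numpy())
# [[1. 0. 0.]
#  [1. 1. 0.]
#  [1. 1. 1.]]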
def call(self, input_tensor, attention_mask=None, is_training=True):
    if self.hidden_size % self.num_attention_heads != 0:
        raise ValueError(
            "The hidden size must be an integer multiple of the number of "
            "attention heads")
    input_shape = get_shape_list(input_tensor)
    input_tensor = reshape_to_matrix(input_tensor)  # [B*T, hidden]

    with tf.keras.backend.name_scope("attention"):
        attention_heads = []
        attention_head = self.attention_layer(input_tensor, input_tensor,
                                              attention_mask)
        attention_heads.append(attention_head)
        if len(attention_heads) == 1:
            attention_output = attention_heads[0]
        else:
            attention_output = tf.concat(attention_heads, axis=-1)
        with tf.keras.backend.name_scope("output"):
            # Project, apply dropout, then a residual layer norm.
            attention_output = self.attention_output_layer(attention_output)
            attention_output = self.dropout(attention_output,
                                            training=is_training)
            attention_output = self.out_layer_norm(attention_output +
                                                   input_tensor)

    with tf.keras.backend.name_scope("intermediate"):
        intermediate_output = self.inter_output(attention_output)

    with tf.keras.backend.name_scope("output"):
        layer_output = self.layer_out(intermediate_output)
        layer_output = self.dropout(layer_output, training=is_training)
        layer_output = self.layer_norm(layer_output + attention_output)

    layer_output = reshape_from_matrix(layer_output, input_shape)
    return layer_output
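# Minimal sketch of the post-layer-norm residual pattern used twice above
# (attention sublayer and feed-forward sublayer): project, dropout, add the
# sublayer input back, then LayerNorm. Layer objects and sizes here are
# illustrative only, not the ones defined in this module.
import tensorflow as tf

def residual_post_ln(sublayer_out, sublayer_in, dense, dropout, layer_norm,
                     training=False):
    out = dense(sublayer_out)
    out = dropout(out, training=training)
    return layer_norm(out + sublayer_in)

dense = tf.keras.layers.Dense(768)
dropout = tf.keras.layers.Dropout(0.1)
ln = tf.keras.layers.LayerNormalization(epsilon=1e-12)
x = tf.random.normal([6, 768])                     # [B*T, hidden]
y = residual_post_ln(tf.random.normal([6, 768]), x, dense, dropout, ln)
print(y.shape)                                     # (6, 768)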
def call(self, inputs, is_training=True):
    # `inputs` packs the three int tensors on a leading axis: [3, B, T].
    input_ids, token_type_ids, input_mask = tf.split(inputs, 3, 0)
    input_ids = tf.cast(tf.squeeze(input_ids, axis=0), tf.int32)
    token_type_ids = tf.cast(tf.squeeze(token_type_ids, axis=0), tf.int32)
    input_mask = tf.cast(tf.squeeze(input_mask, axis=0), tf.int32)

    input_shape = get_shape_list(input_ids)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    if input_mask is None:
        input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)
    if token_type_ids is None:
        token_type_ids = tf.zeros(shape=[batch_size, seq_length],
                                  dtype=tf.int32)

    self.embedding_output = self.token_embedding(input_ids)
    self.embedding_output = self.segposembedding(self.embedding_output,
                                                 token_type_ids, is_training)

    with tf.keras.backend.name_scope("encoder"):
        attention_mask = create_attention_mask_from_input_mask(
            input_ids, input_mask)
        self.all_layer_outputs = []
        layer_encode_output = self.embedding_output  # [B, T, hidden]
        for encoder_layer in self.encoder_layers:
            layer_encode_input = layer_encode_output
            layer_encode_output = encoder_layer(layer_encode_input,
                                                attention_mask, is_training)
            self.all_layer_outputs.append(layer_encode_output)
        self.sequence_output = layer_encode_output
    return self
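# Illustrative sketch of the input packing this `call` expects: the three int
# tensors are stacked on a leading axis so `tf.split(inputs, 3, 0)` can peel
# them apart again. The ids below are dummy values.
import tensorflow as tf

input_ids = tf.constant([[101, 2009, 102], [101, 2054, 102]])   # [B, T]
token_type_ids = tf.zeros_like(input_ids)
input_mask = tf.ones_like(input_ids)

packed = tf.stack([input_ids, token_type_ids, input_mask], axis=0)  # [3, B, T]
ids, types, mask = tf.split(packed, 3, 0)
ids = tf.squeeze(ids, axis=0)                                       # back to [B, T]
assert ids.shape == input_ids.shape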
def einsum_via_matmul(input_tensor, w, num_inner_dims):
    input_shape = get_shape_list(input_tensor)
    w_shape = get_shape_list(w)
    batch_dims = input_shape[:-num_inner_dims]
    inner_dims = input_shape[-num_inner_dims:]
    outer_dims = w_shape[num_inner_dims:]
    inner_dim = np.prod(inner_dims)
    outer_dim = np.prod(outer_dims)
    if num_inner_dims > 1:
        input_tensor = tf.reshape(input_tensor, batch_dims + [inner_dim])
    if len(w_shape) > 2:
        w = tf.reshape(w, [inner_dim, outer_dim])
    ret = tf.matmul(input_tensor, w)
    if len(outer_dims) > 1:
        ret = tf.reshape(ret, batch_dims + outer_dims)
    return ret
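# Sketch of what `einsum_via_matmul` computes: the inner dims of the input are
# contracted against the leading dims of `w` by flattening both sides and doing
# one matmul. For num_inner_dims == 2 this matches tf.einsum("abij,ijk->abk").
# Shapes are dummies; `get_shape_list` and `np` from this module are assumed.
import tensorflow as tf

x = tf.random.normal([2, 3, 4, 5])   # batch dims [2, 3], inner dims [4, 5]
w = tf.random.normal([4, 5, 6])      # contracts the two inner dims to one outer dim

ref = tf.einsum("abij,ijk->abk", x, w)
out = einsum_via_matmul(x, w, num_inner_dims=2)
tf.debugging.assert_near(ref, out)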
def call(self, x):
    bz, sl = get_shape_list(x)[:2]
    x = tf.reshape(x, [-1, self.nx])
    w = tf.reshape(self.weight, [-1, self.nf])
    x = tf.matmul(x, w) + self.bias
    x = tf.reshape(x, [bz, sl, self.nf])
    return x
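# Sketch of the Conv1D-style projection above: flatten [B, T, nx] to a matrix,
# matmul with an [nx, nf] weight, add the bias, and reshape back. It is the
# same contraction as an einsum over the last axis; tensors below are dummies.
import tensorflow as tf

bz, sl, nx, nf = 2, 3, 4, 5
x = tf.random.normal([bz, sl, nx])
w = tf.random.normal([nx, nf])
b = tf.zeros([nf])

via_reshape = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), w) + b, [bz, sl, nf])
via_einsum = tf.einsum("bsi,io->bso", x, w) + b
tf.debugging.assert_near(via_reshape, via_einsum)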
def call(self, inputs, is_training=True):
    # `inputs` packs the three int tensors on a leading axis: [3, B, T].
    input_ids, token_type_ids, input_mask = tf.split(inputs, 3, 0)
    input_ids = tf.cast(tf.squeeze(input_ids, axis=0), tf.int32)
    token_type_ids = tf.cast(tf.squeeze(token_type_ids, axis=0), tf.int32)
    input_mask = tf.cast(tf.squeeze(input_mask, axis=0), tf.int32)

    input_shape = get_shape_list(input_ids)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    if input_mask is None:
        input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)
    if token_type_ids is None:
        token_type_ids = tf.zeros(shape=[batch_size, seq_length],
                                  dtype=tf.int32)

    with tf.keras.backend.name_scope("bert"):
        self.embedding_output = self.token_embedding(input_ids)
        self.embedding_output = self.segposembedding(self.embedding_output,
                                                     token_type_ids,
                                                     is_training)
        with tf.keras.backend.name_scope("encoder"):
            input_shape = get_shape_list(self.embedding_output, expected_rank=3)
            input_width = input_shape[2]
            self.all_layer_outputs = []
            # Project the factorized embedding up to the hidden size if needed.
            if input_width != self.hidden_size:
                prev_output = self.shape_change(self.embedding_output)
            else:
                prev_output = self.embedding_output
            with tf.keras.backend.name_scope("transformer"):
                for i in range(self.num_hidden_layers):
                    group_idx = int(i / self.num_hidden_layers *
                                    self.num_hidden_groups)
                    with tf.keras.backend.name_scope("group_%d" % group_idx):
                        layer_output = prev_output
                        for inner_group_idx in range(self.inner_group_num):
                            # The same `encoder_layer` object is reused on every
                            # iteration, i.e. cross-layer parameter sharing.
                            layer_output = self.encoder_layer(layer_output,
                                                              input_mask,
                                                              is_training)
                            prev_output = layer_output
                            self.all_layer_outputs.append(layer_output)
    self.sequence_output = layer_output
    return self
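# Illustrative sketch of the cross-layer parameter sharing above: one layer
# object applied `num_hidden_layers` times means the variable count does not
# grow with depth. Demonstrated here with a plain Dense layer, not the encoder
# layer defined in this module.
import tensorflow as tf

shared = tf.keras.layers.Dense(8)
x = tf.random.normal([2, 8])
for _ in range(12):                        # twelve "layers", one set of weights
    x = shared(x)
print(len(shared.trainable_variables))     # 2 (kernel + bias), regardless of depth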
def call(self, input_ids):
    if input_ids.shape.ndims == 2:
        input_ids = tf.expand_dims(input_ids, axis=-1)
    flat_input_ids = tf.reshape(input_ids, [-1])
    if self.use_one_hot_embedding:
        # One-hot matmul lookup (the TPU-friendly path).
        one_hot_input_ids = tf.keras.backend.one_hot(flat_input_ids,
                                                     self.vocab_size)
        output = tf.linalg.matmul(one_hot_input_ids, self.embedding_table)
    else:
        output = tf.gather(self.embedding_table, flat_input_ids)
    input_shape = get_shape_list(input_ids)
    output = tf.reshape(
        output, input_shape[0:-1] + [input_shape[-1] * self.embedding_size])
    return output
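# Sketch of the two lookup paths above: a one-hot matmul and tf.gather return
# the same embedding rows. Table and ids below are dummy values.
import tensorflow as tf

vocab_size, embedding_size = 6, 4
table = tf.random.normal([vocab_size, embedding_size])
ids = tf.constant([1, 3, 3, 0])

one_hot_lookup = tf.matmul(tf.one_hot(ids, vocab_size), table)
gather_lookup = tf.gather(table, ids)
tf.debugging.assert_near(one_hot_lookup, gather_lookup)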
def call(self, input_tensor, token_type_ids=None, is_training=True):
    inputshape = get_shape_list(input_tensor, expected_rank=3)
    batch_size = inputshape[0]
    seq_length = inputshape[1]
    width = inputshape[2]
    output = input_tensor

    # segment (token type) features
    if self.use_token_type:
        if token_type_ids is None:
            raise ValueError(
                "token_type_ids must be specified if use_token_type is True")
        if self.use_one_hot_embedding:
            flat_token_type_ids = tf.reshape(token_type_ids, [-1])
            one_hot_ids = tf.one_hot(flat_token_type_ids,
                                     depth=self.token_type_vocab_size)
            token_type_embeddings = tf.linalg.matmul(one_hot_ids,
                                                     self.token_type_table)
            token_type_embeddings = tf.reshape(token_type_embeddings,
                                               [batch_size, seq_length, width])
        else:
            token_type_embeddings = tf.gather(self.token_type_table,
                                              token_type_ids)
        output += token_type_embeddings

    # position features
    if self.use_position_embeddings:
        position_embeddings = tf.slice(self.full_position_embeddings, [0, 0],
                                       [seq_length, -1])
        num_dims = len(output.shape.as_list())
        position_broadcast_shape = []
        for _ in range(num_dims - 2):
            position_broadcast_shape.append(1)
        position_broadcast_shape.extend([seq_length, width])
        position_embeddings = tf.reshape(position_embeddings,
                                         position_broadcast_shape)
        output += position_embeddings

    output = self.layer_norm(output)
    # The official implementation does not pass `training` to this dropout.
    output = self.drop_out(output, training=is_training)
    return output
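# Sketch of the position-embedding broadcast above: slice the first
# `seq_length` rows of the full table, reshape to [1, seq_length, width], and
# let addition broadcast the same embeddings to every batch row. Dummy shapes.
import tensorflow as tf

batch_size, seq_length, width, max_position = 2, 3, 4, 10
full_table = tf.random.normal([max_position, width])
output = tf.zeros([batch_size, seq_length, width])

position_embeddings = tf.slice(full_table, [0, 0], [seq_length, -1])
position_embeddings = tf.reshape(position_embeddings, [1, seq_length, width])
output += position_embeddings          # same positions added to every batch row
print(output.shape)                    # (2, 3, 4)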
def call(self, from_tensor, to_tensor=None, attention_mask=None,
         is_training=True):
    from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
    to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])
    if len(from_shape) != len(to_shape):
        raise ValueError(
            "The rank of `from_tensor` must match the rank of `to_tensor`.")
    if len(from_shape) == 3:
        self.batch_size = from_shape[0]
        self.from_seq_length = from_shape[1]
        self.to_seq_length = to_shape[1]
    elif len(from_shape) == 2:
        if (self.batch_size is None or self.from_seq_length is None
                or self.to_seq_length is None):
            raise ValueError(
                "When passing in rank 2 tensors to attention_layer, the values "
                "for `batch_size`, `from_seq_length`, and `to_seq_length` "
                "must all be specified.")

    # `query_layer` = [B, F, N, H]
    q = self.q(from_tensor)
    # `key_layer` = [B, T, N, H]
    k = self.k(to_tensor)
    # `value_layer` = [B, T, N, H]
    v = self.v(to_tensor)

    q = tf.transpose(q, [0, 2, 1, 3])
    k = tf.transpose(k, [0, 2, 1, 3])
    v = tf.transpose(v, [0, 2, 1, 3])

    if attention_mask is not None:
        attention_mask = tf.reshape(
            attention_mask, [self.batch_size, 1, self.to_seq_length, 1])

    # attention scores: [B, N, F, T]
    logits = tf.linalg.matmul(q, k, transpose_b=True)
    logits = tf.multiply(logits,
                         1.0 / math.sqrt(float(get_shape_list(q)[-1])))
    if attention_mask is not None:
        from_shape = get_shape_list(q)
        if len(from_shape) == 4:
            broadcast_ones = tf.ones([from_shape[0], 1, from_shape[2], 1],
                                     tf.float32)
        elif len(from_shape) == 5:
            # from_shape = [B, N, block_num, block_size, depth]
            broadcast_ones = tf.ones(
                [from_shape[0], 1, from_shape[2], from_shape[3], 1],
                tf.float32)
        # Expand the [B, 1, T, 1] mask to [B, 1, F, T], then turn zeros into a
        # large negative additive bias before the softmax.
        attention_mask = tf.matmul(broadcast_ones,
                                   tf.cast(attention_mask, tf.float32),
                                   transpose_b=True)
        adder = (1.0 - attention_mask) * -10000.0
        logits += adder

    attention_probs = tf.math.softmax(logits, name="attention_probs")
    attention_probs = self.drop_out(attention_probs, training=is_training)
    context_layer = tf.linalg.matmul(attention_probs, v)
    return tf.transpose(context_layer, [0, 2, 1, 3])
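# Sketch of the additive masking above: key positions with mask 0 receive a
# large negative bias, so softmax drives their attention weight to ~0. Values
# are dummies.
import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.5]])
mask = tf.constant([[1.0, 1.0, 0.0]])     # last key position is padding
adder = (1.0 - mask) * -10000.0
probs = tf.math.softmax(logits + adder, axis=-1)
print(probs.numpy())                      # third weight is effectively 0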