def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
    # q, k, v have shape [batch, heads, sequence, features]
    w = tf.matmul(q, k, transpose_b=True)
    if self.scale:
        dk = tf.cast(get_shape(k)[-1], dtype=w.dtype)  # scale attention_scores
        w = w / tf.math.sqrt(dk)

    # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
    _, _, nd, ns = get_shape(w)
    b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
    b = tf.reshape(b, [1, 1, nd, ns])
    w = w * b - 1e4 * (1 - b)

    if attention_mask is not None:
        # Apply the attention mask
        w = w + attention_mask

    w = tf.nn.softmax(w, axis=-1)
    w = self.attn_dropout(w, training=training)

    # Mask heads if we want to
    if head_mask is not None:
        w = w * head_mask

    outputs = [tf.matmul(w, v)]
    if output_attentions:
        outputs.append(w)
    return outputs
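# Hedged sketch (not part of the module above): the `w * b - 1e4 * (1 - b)` trick keeps scores
# where the lower-triangular mask b == 1 and pushes future positions to a large negative value,
# so softmax assigns them ~0 probability. `make_causal_band` below is a hypothetical stand-in
# for self.causal_attention_mask.
import tensorflow as tf

def make_causal_band(nd, ns, dtype=tf.float32):
    # 1.0 where key position j may be attended from query position i, 0.0 otherwise
    i = tf.range(nd)[:, None]
    j = tf.range(ns)
    return tf.cast(i >= j - ns + nd, dtype)

def causal_softmax_example():
    scores = tf.zeros([1, 1, 4, 4])                       # [batch, heads, dst, src], all equal scores
    b = tf.reshape(make_causal_band(4, 4), [1, 1, 4, 4])
    masked = scores * b - 1e4 * (1 - b)
    probs = tf.nn.softmax(masked, axis=-1)
    # first row attends only to position 0, last row attends uniformly to all 4 positions
    return probs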
def _linear(self, inputs):
    """Computes logits by running inputs through a linear layer.

    Args:
        inputs: A float32 tensor with shape [batch_size, length, hidden_size]

    Returns:
        float32 tensor with shape [batch_size, length, vocab_size].
    """
    batch_size = get_shape(inputs)[0]
    length = get_shape(inputs)[1]
    x = tf.reshape(inputs, [-1, self.embedding_size])
    logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
    return tf.reshape(logits, [batch_size, length, self.vocab_size])
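# Hedged sketch of the tied-embedding projection used by _linear above (toy shapes, standalone):
# reusing the [vocab_size, hidden_size] embedding matrix with transpose_b=True turns hidden
# states into vocabulary logits without a separate output weight.
import tensorflow as tf

def tied_logits_example():
    vocab_size, hidden_size, batch_size, length = 11, 8, 2, 5
    word_embeddings = tf.random.normal([vocab_size, hidden_size])
    hidden = tf.random.normal([batch_size, length, hidden_size])
    x = tf.reshape(hidden, [-1, hidden_size])
    logits = tf.matmul(x, word_embeddings, transpose_b=True)
    return tf.reshape(logits, [batch_size, length, vocab_size])  # [2, 5, 11]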
def call(self, x):
    bz, sl = get_shape(x)[:2]
    x = tf.reshape(x, [-1, self.nx])
    x = tf.matmul(x, self.weight) + self.bias
    x = tf.reshape(x, [bz, sl, self.nf])
    return x
def rel_shift(x, klen=-1):
    """Perform relative shift to form the relative attention score."""
    x_size = get_shape(x)

    x = tf.reshape(x, (x_size[1], x_size[0], x_size[2], x_size[3]))
    x = x[1:, ...]
    x = tf.reshape(x, (x_size[0], x_size[1] - 1, x_size[2], x_size[3]))
    x = x[:, 0:klen, :, :]
    # x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long))

    return x
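# Hedged shape sketch (assumption: rel_shift receives position-based scores of shape
# [qlen, pos_len, bsz, n_head] and truncates them to the key length klen). The same
# reshape/slice trick is repeated inline here so the example does not depend on where
# rel_shift is defined.
import tensorflow as tf

def rel_shift_shape_example():
    qlen, pos_len, klen, bsz, n_head = 3, 8, 5, 2, 4
    bd = tf.random.normal([qlen, pos_len, bsz, n_head])
    x = tf.reshape(bd, [pos_len, qlen, bsz, n_head])[1:]        # drop the leading misaligned row
    x = tf.reshape(x, [qlen, pos_len - 1, bsz, n_head])[:, :klen]
    assert x.shape.as_list() == [qlen, klen, bsz, n_head]
    return x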
def _linear(self, inputs):
    """Computes logits by running inputs through a linear layer.

    Args:
        inputs: A float32 tensor with shape [..., hidden_size]

    Returns:
        float32 tensor with shape [..., vocab_size].
    """
    first_dims = get_shape(inputs)[:-1]
    x = tf.reshape(inputs, [-1, self.hidden_size])
    logits = tf.matmul(x, self.weight, transpose_b=True)
    return tf.reshape(logits, first_dims + [self.vocab_size])
def call(self, y_true, y_pred):
    batch_size = get_shape(y_true)[0]
    y_true = tf.reshape(y_true, [batch_size, -1])
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.reshape(y_pred, [batch_size, -1])
    # p_t = y_pred * y_true + (1. - y_pred) * (1. - y_true)
    num = tf.reduce_sum((1.0 - y_pred) * y_pred * y_true, axis=1) + self.smooth
    den = tf.reduce_sum((1.0 - y_pred) * y_pred + y_true, axis=1) + self.smooth
    loss = 1 - num / den
    return tf.reduce_sum(loss)
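# Hedged standalone sketch of the soft-dice-style score computed above (smooth is assumed to be
# 1.0 here; the class attribute self.smooth plays the same role). The (1 - y_pred) factor
# appears to down-weight already-confident positives.
import tensorflow as tf

def soft_dice_example(smooth=1.0):
    y_true = tf.constant([[1., 0., 1., 0.]])
    y_pred = tf.constant([[0.9, 0.1, 0.6, 0.4]])
    num = tf.reduce_sum((1.0 - y_pred) * y_pred * y_true, axis=1) + smooth
    den = tf.reduce_sum((1.0 - y_pred) * y_pred + y_true, axis=1) + smooth
    return tf.reduce_sum(1 - num / den)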
def loss(self, y_true, y_pred):
    """Computes the log-likelihood of tag sequences in a CRF.

    Args:
        y_true: A (batch_size, n_steps) tensor.
        y_pred: A (batch_size, n_steps, n_classes) tensor.

    Returns:
        loss: A scalar containing the mean negative log-likelihood of the given sequences of tag indices.
    """
    batch_size, n_steps, _ = get_shape(y_pred)
    y_true = tf.cast(tf.reshape(y_true, [batch_size, n_steps]), dtype='int32')
    log_likelihood, self.transition_params = tfa.text.crf_log_likelihood(
        y_pred, y_true, self.sequence_length, self.transition_params)
    loss = tf.reduce_mean(-log_likelihood)
    return loss
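# Hedged usage sketch for the CRF loss above (toy shapes, assuming the tensorflow_addons text
# API): crf_log_likelihood scores the gold tag paths, crf_decode recovers the best path.
import tensorflow as tf
import tensorflow_addons as tfa

def crf_toy_example():
    batch_size, n_steps, n_classes = 2, 7, 5
    logits = tf.random.normal([batch_size, n_steps, n_classes])
    tags = tf.random.uniform([batch_size, n_steps], maxval=n_classes, dtype=tf.int32)
    lengths = tf.fill([batch_size], n_steps)
    log_likelihood, transitions = tfa.text.crf_log_likelihood(logits, tags, lengths)
    loss = tf.reduce_mean(-log_likelihood)
    decoded_tags, _ = tfa.text.crf_decode(logits, transitions, lengths)
    return loss, decoded_tags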
def call(self, y_true, y_pred):
    y_pred_shape = get_shape(y_pred)
    y_true = tf.cast(y_true, tf.int32)
    y_true = tf.one_hot(y_true, y_pred_shape[-1])
    y_true = tf.reshape(y_true, y_pred_shape)
    loss = softmax_focal_crossentropy(
        y_true,
        y_pred,
        alpha=self.alpha,
        gamma=self.gamma,
        from_logits=self.from_logits,
    )
    return loss
def call(self, y_true, y_pred): """ :param y_pred: [batch_size, seq_len, label_size] :param y_true: :return: """ true_rank, pred_rank = y_true.shape.ndims, y_pred.shape.ndims if true_rank + 1 == pred_rank: label_num = get_shape(y_pred)[-1] y_true = tf.one_hot(y_true, label_num) true_rank = y_true.shape.ndims assert true_rank == pred_rank, \ f"For the tensor y_true, the actual tensor rank {true_rank} (shape = {get_shape(y_true)}) " \ f"is not equal to the expected tensor rank {pred_rank}" pred_shape = get_shape(y_pred) y_true = tf.reshape(y_true, [pred_shape[0], self.seq_len, self.label_num]) y_pred = tf.reshape(y_pred, [pred_shape[0], self.seq_len, self.label_num]) total_loss = 0 y_pred = tf.math.softmax(y_pred) y_pred = tf.unstack(y_pred, axis=-1) y_true = tf.unstack(y_true, axis=-1) for i in range(self.label_num): if i != self.ignore_index: dice_loss = self.binary_dce_loss(y_true[i], y_pred[i]) if self.weight is not None: assert len(self.weight) == pred_shape[-1], \ f'Expect weight shape [{pred_shape[-1]}], get[{len(self.weight)}]' dice_loss *= self.weights[i] total_loss += dice_loss return total_loss / pred_shape[-1]
def _embedding(self, inputs, training=False):
    """Applies embedding based on inputs tensor."""
    input_ids, position_ids, token_type_ids, inputs_embeds = inputs

    if input_ids is not None:
        input_shape = get_shape(input_ids)
    else:
        input_shape = get_shape(inputs_embeds)[:-1]

    seq_length = input_shape[1]
    if position_ids is None:
        position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
    if token_type_ids is None:
        token_type_ids = tf.fill(input_shape, 0)
    if inputs_embeds is None:
        inputs_embeds = tf.gather(self.word_embeddings, input_ids)

    position_embeddings = self.position_embeddings(position_ids)
    token_type_embeddings = self.token_type_embeddings(token_type_ids)

    embeddings = inputs_embeds + position_embeddings + token_type_embeddings
    embeddings = self.layer_norm(embeddings)
    embeddings = self.dropout(embeddings, training=training)
    return embeddings
def rel_attn_core(self, inputs, training=False):
    """Core relative positional attention operations."""
    q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask = inputs

    # content based attention score
    ac = tf.einsum('ibnd,jbnd->ijbn', q_head + self.r_w_bias, k_head_h)

    # position based attention score
    bd = tf.einsum('ibnd,jbnd->ijbn', q_head + self.r_r_bias, k_head_r)
    bd = self.rel_shift(bd, klen=get_shape(ac)[1])

    # segment based attention score
    if seg_mat is None:
        ef = 0
    else:
        ef = tf.einsum('ibnd,snd->ibns', q_head + self.r_s_bias, self.seg_embed)
        ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef)

    # merge attention scores and perform masking
    attn_score = (ac + bd + ef) * self.scale
    if attn_mask is not None:
        # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
        if attn_mask.dtype == tf.float16:
            attn_score = attn_score - 65500 * attn_mask
        else:
            attn_score = attn_score - 1e30 * attn_mask

    # attention probability
    attn_prob = tf.nn.softmax(attn_score, axis=1)
    attn_prob = self.dropout(attn_prob, training=training)

    # Mask heads if we want to
    if head_mask is not None:
        attn_prob = attn_prob * head_mask

    # attention output
    attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)

    if self.output_attentions:
        return attn_vec, attn_prob
    return attn_vec
def call(self, inputs, **kwargs):
    # inputs: entity_span_start / entity_span_end have shape [batch_size, 2]
    entity_span_start = inputs['entity_span_start']
    entity_span_end = inputs['entity_span_end']
    # entity_labels = inputs['entity_labels']
    batch_size, max_entity_num = get_shape(entity_span_start)
    entity_span_start = tf.reshape(entity_span_start, [batch_size, 2])
    entity_span_end = tf.reshape(entity_span_end, [batch_size, 2])

    # bert encode [batch_size, seq_len, hidden_size]
    bert_encode = self.bert(inputs, **kwargs)
    seq_output = bert_encode[0]

    # build repr for entities
    entity_start_repr = tf_gather(seq_output, entity_span_start)  # [batch_size, 2, hidden_size]
    entity_end_repr = tf_gather(seq_output, entity_span_end)  # [batch_size, 2, hidden_size]
    # [batch_size, 2, hidden_size * 2]
    entity_repr = tf.concat([entity_start_repr, entity_end_repr], -1)
    # [batch_size, 2, hidden_size]
    entity_repr = self.project1(entity_repr)
    entity_repr = self.dropout(entity_repr, training=kwargs.get('training', False))

    # fusion of every entity pair
    entity_reprs = tf.unstack(entity_repr, axis=1)
    entity_repr1_r = tf.reshape(entity_reprs[0], [batch_size, self.hidden_size])
    entity_repr2_r = tf.reshape(entity_reprs[1], [batch_size, self.hidden_size])

    # 1 just entity_pair repr
    # entity_pair_repr = tf.concat([entity_repr1, entity_repr2], -1)
    # entity_pair_repr = self.project2(entity_pair_repr)
    # entity_pair_repr = self.dropout(entity_pair_repr, training=kwargs.get('training', False))

    # 2 combine seq and entity_pair reprs
    # cls_output = bert_encode[1]
    # entity_pair_repr = tf.concat([cls_output, entity_repr1, entity_repr2], -1)
    # entity_pair_repr = self.project2(entity_pair_repr)
    # entity_pair_repr = self.dropout(entity_pair_repr, training=kwargs.get('training', False))

    # 3 attention
    # attention repr for entities
    entity_repr1 = self.e1_attention(
        tf.expand_dims(entity_repr1_r, 1), seq_output, seq_output,
        training=kwargs.get('training', False))
    entity_repr1 = tf.reshape(entity_repr1, [batch_size, self.hidden_size])
    entity_repr1 = entity_repr1 + entity_repr1_r

    entity_repr2 = self.e2_attention(
        tf.expand_dims(entity_repr2_r, 1), seq_output, seq_output,
        training=kwargs.get('training', False))
    entity_repr2 = tf.reshape(entity_repr2, [batch_size, self.hidden_size])
    entity_repr2 = entity_repr2 + entity_repr2_r

    # relation repr
    entity_pair_repr_0 = tf.concat(
        [entity_repr1, entity_repr2, entity_repr1 - entity_repr2], -1)
    entity_pair_repr = self.project2(entity_pair_repr_0)
    entity_pair_repr1 = self.dropout(entity_pair_repr, training=kwargs.get('training', False))
    entity_pair_repr = tf.expand_dims(entity_pair_repr1, 1)

    # attention repr for the relation repr
    entity_pair_repr = self.attention(
        entity_pair_repr, seq_output, seq_output,
        training=kwargs.get('training', False))
    entity_pair_repr = tf.reshape(entity_pair_repr[0], [batch_size, self.hidden_size])
    entity_pair_repr = entity_pair_repr1 + entity_pair_repr

    entity_pair_repr = tf.concat([entity_pair_repr, bert_encode[1]], -1)
    entity_pair_repr = self.project3(entity_pair_repr)
    entity_pair_repr = self.project4(entity_pair_repr)
    entity_pair_repr = self.project5(entity_pair_repr)
    entity_pair_repr = self.dropout(entity_pair_repr, training=kwargs.get("training", False))

    # classifier
    logits = self.classifer(entity_pair_repr)
    return logits
def call(self, inputs, training=False):
    """
    hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
    cls_index: [optional] position of the classification token if summary_type == 'cls_index',
        shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
        if summary_type == 'cls_index' and cls_index is None:
            we take the last token of the sequence as classification token
    """
    if not isinstance(inputs, (dict, tuple, list)):
        hidden_states = inputs
        cls_index = None
    elif isinstance(inputs, (tuple, list)):
        hidden_states = inputs[0]
        cls_index = inputs[1] if len(inputs) > 1 else None
        assert len(inputs) <= 2, "Too many inputs."
    else:
        hidden_states = inputs.get('hidden_states')
        cls_index = inputs.get('cls_index', None)

    if self.summary_type == 'last':
        output = hidden_states[:, -1]
    elif self.summary_type == 'first':
        output = hidden_states[:, 0]
    elif self.summary_type == 'mean':
        output = tf.reduce_mean(hidden_states, axis=1)
    elif self.summary_type == 'cls_index':
        hidden_shape = get_shape(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
        if cls_index is None:
            cls_index = tf.fill(
                hidden_shape[:-2], hidden_shape[-2] - 1
            )  # A tensor of shape [batch] or [batch, num choices] filled with the sequence length
        cls_shape = get_shape(cls_index)
        if len(cls_shape) <= len(hidden_shape) - 2:
            cls_index = cls_index[..., tf.newaxis]
        # else:
        #     cls_index = cls_index[..., tf.newaxis]
        # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
        # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
        output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
        output = tf.squeeze(
            output, axis=len(hidden_shape) - 2)  # shape of output: (batch, num choices, hidden_size)
    elif self.summary_type == 'attn':
        raise NotImplementedError

    if self.has_first_dropout:
        output = self.first_dropout(output, training=training)
    if self.has_summary:
        output = self.summary(output)
    if self.has_activation:
        output = self.activation(output)
    if self.has_last_dropout:
        output = self.last_dropout(output, training=training)
    return output
def call(self,
         inputs,
         attention_mask=None,
         mems=None,
         perm_mask=None,
         target_mapping=None,
         token_type_ids=None,
         input_mask=None,
         head_mask=None,
         inputs_embeds=None,
         training=False):
    if isinstance(inputs, (tuple, list)):
        input_ids = inputs[0]
        attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
        mems = inputs[2] if len(inputs) > 2 else mems
        perm_mask = inputs[3] if len(inputs) > 3 else perm_mask
        target_mapping = inputs[4] if len(inputs) > 4 else target_mapping
        token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
        input_mask = inputs[6] if len(inputs) > 6 else input_mask
        head_mask = inputs[7] if len(inputs) > 7 else head_mask
        inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
        assert len(inputs) <= 9, "Too many inputs."
    elif isinstance(inputs, dict):
        input_ids = inputs.get('input_ids')
        attention_mask = inputs.get('attention_mask', attention_mask)
        mems = inputs.get('mems', mems)
        perm_mask = inputs.get('perm_mask', perm_mask)
        target_mapping = inputs.get('target_mapping', target_mapping)
        token_type_ids = inputs.get('token_type_ids', token_type_ids)
        input_mask = inputs.get('input_mask', input_mask)
        head_mask = inputs.get('head_mask', head_mask)
        inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
        assert len(inputs) <= 9, "Too many inputs."
    else:
        input_ids = inputs

    # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
    # but we want a unified interface in the library with the batch size on the first dimension
    # so we move here the first dimension (batch) to the end
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_ids = tf.transpose(input_ids, perm=(1, 0))
        qlen, bsz = get_shape(input_ids)[:2]
    elif inputs_embeds is not None:
        inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
        qlen, bsz = get_shape(inputs_embeds)[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
    input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
    attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
    perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
    target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None

    mlen = get_shape(mems[0])[0] if mems is not None and mems[0] is not None else 0
    klen = mlen + qlen

    dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32

    ##### Attention mask
    # causal attention mask
    if self.attn_type == 'uni':
        attn_mask = self.create_mask(qlen, mlen)
        attn_mask = attn_mask[:, :, None, None]
    elif self.attn_type == 'bi':
        attn_mask = None
    else:
        raise ValueError('Unsupported attention type: {}'.format(self.attn_type))

    # data mask: input mask & perm mask
    assert input_mask is None or attention_mask is None, \
        "You can only use one of input_mask (uses 1 for padding) " \
        "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
    if input_mask is None and attention_mask is not None:
        attention_mask = tf.cast(attention_mask, tf.float32)
        input_mask = 1.0 - attention_mask

    if input_mask is not None and perm_mask is not None:
        data_mask = input_mask[None] + perm_mask
    elif input_mask is not None and perm_mask is None:
        data_mask = input_mask[None]
    elif input_mask is None and perm_mask is not None:
        data_mask = perm_mask
    else:
        data_mask = None

    if data_mask is not None:
        # all mems can be attended to
        mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz], dtype=dtype_float)
        data_mask = tf.concat([mems_mask, data_mask], axis=1)
        if attn_mask is None:
            attn_mask = data_mask[:, :, :, None]
        else:
            attn_mask += data_mask[:, :, :, None]

    if attn_mask is not None:
        attn_mask = tf.cast(attn_mask > 0, dtype=dtype_float)

    if attn_mask is not None:
        non_tgt_mask = -tf.eye(qlen, dtype=dtype_float)
        non_tgt_mask = tf.concat(
            [tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1)
        non_tgt_mask = tf.cast(
            (attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float)
    else:
        non_tgt_mask = None

    ##### Word embeddings and prepare h & g hidden states
    if inputs_embeds is not None:
        word_emb_k = inputs_embeds
    else:
        word_emb_k = self.word_embedding(input_ids)
    output_h = self.dropout(word_emb_k, training=training)
    if target_mapping is not None:
        word_emb_q = tf.tile(self.mask_emb, [tf.shape(target_mapping)[0], bsz, 1])
        # else:  # We removed the inp_q input which was same as target mapping
        #     inp_q_ext = inp_q[:, :, None]
        #     word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
        output_g = self.dropout(word_emb_q, training=training)
    else:
        output_g = None

    ##### Segment embedding
    if token_type_ids is not None:
        # Convert `token_type_ids` to one-hot `seg_mat`
        mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
        cat_ids = tf.concat([mem_pad, token_type_ids], 0)
        # `1` indicates not in the same segment [qlen x klen x bsz]
        seg_mat = tf.cast(
            tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32)
        seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float)
    else:
        seg_mat = None

    ##### Positional encoding
    pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
    pos_emb = self.dropout(pos_emb, training=training)

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
    # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
    if head_mask is not None:
        # The reference (PyTorch) conversion below relies on .dim()/.unsqueeze()/.expand()/.to(),
        # which do not exist on TF tensors, so per-head masking is not supported here.
        # if head_mask.dim() == 1:
        #     head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
        #     head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
        # elif head_mask.dim() == 2:
        #     head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
        # head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
        raise NotImplementedError
    else:
        head_mask = [None] * self.n_layer

    new_mems = ()
    if mems is None:
        mems = [None] * len(self.layer)

    attentions = []
    hidden_states = []
    for i, layer_module in enumerate(self.layer):
        # cache new mems
        if self.mem_len is not None and self.mem_len > 0 and self.output_past:
            new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
        if self.output_hidden_states:
            hidden_states.append((output_h, output_g) if output_g is not None else output_h)

        outputs = layer_module([
            output_h, output_g, non_tgt_mask, attn_mask, pos_emb, seg_mat, mems[i],
            target_mapping, head_mask[i]
        ], training=training)
        output_h, output_g = outputs[:2]
        if self.output_attentions:
            attentions.append(outputs[2])

    # Add last hidden state
    if self.output_hidden_states:
        hidden_states.append((output_h, output_g) if output_g is not None else output_h)

    output = self.dropout(output_g if output_g is not None else output_h, training=training)

    # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
    output = tf.transpose(output, perm=(1, 0, 2))
    first_pooling = self.pooler(output)
    outputs = (output, first_pooling)

    if self.mem_len is not None and self.mem_len > 0 and self.output_past:
        outputs = outputs + (new_mems,)

    if self.output_hidden_states:
        if output_g is not None:
            hidden_states = tuple(
                tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
        else:
            hidden_states = tuple(
                tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
        outputs = outputs + (hidden_states,)

    if self.output_attentions:
        attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
        outputs = outputs + (attentions,)

    return outputs  # outputs, (new_mems), (hidden_states), (attentions)
def split_heads(self, x):
    x_shape = get_shape(x)
    new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
    x = tf.reshape(x, new_x_shape)
    return tf.transpose(x, (0, 2, 1, 3))  # (batch, head, seq_length, head_features)
def merge_heads(self, x):
    x = tf.transpose(x, [0, 2, 1, 3])
    x_shape = get_shape(x)
    new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
    return tf.reshape(x, new_x_shape)
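# Hedged shape sketch for split_heads/merge_heads above (standalone, toy shapes): splitting the
# feature axis into [n_head, head_features], moving heads in front, and merging back is a
# lossless round trip.
import tensorflow as tf

def split_merge_roundtrip_example():
    batch, seq_len, n_head, head_features = 2, 6, 4, 8
    x = tf.random.normal([batch, seq_len, n_head * head_features])
    split = tf.transpose(
        tf.reshape(x, [batch, seq_len, n_head, head_features]),
        [0, 2, 1, 3])                                       # (batch, head, seq_length, head_features)
    merged = tf.reshape(
        tf.transpose(split, [0, 2, 1, 3]), [batch, seq_len, n_head * head_features])
    tf.debugging.assert_near(x, merged)                     # round trip preserves values
    return split, merged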
def call(
    self,
    input_ids=None,
    past=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    training=False,
    **kwargs,
):
    inputs = input_processing(
        func=self.call,
        config=self.config,
        input_ids=input_ids,
        past=past,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        training=training,
        kwargs_call=kwargs,
    )

    if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif inputs["input_ids"] is not None:
        input_shape = get_shape(inputs["input_ids"])
        inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
    elif inputs["inputs_embeds"] is not None:
        input_shape = get_shape(inputs["inputs_embeds"])[:-1]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if inputs["past"] is None:
        past_length = 0
        inputs["past"] = [None] * len(self.h)
    else:
        past_length = get_shape(inputs["past"][0][0])[-2]

    if inputs["position_ids"] is None:
        inputs["position_ids"] = tf.range(
            past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]

    if inputs["attention_mask"] is not None:
        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        inputs["attention_mask"] = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
        inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
    else:
        inputs["attention_mask"] = None

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
    # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
    if inputs["head_mask"] is not None:
        raise NotImplementedError
    else:
        inputs["head_mask"] = [None] * self.num_hidden_layers
        # head_mask = tf.constant([0] * self.num_hidden_layers)

    inputs["position_ids"] = tf.reshape(
        inputs["position_ids"], [-1, get_shape(inputs["position_ids"])[-1]])

    if inputs["inputs_embeds"] is None:
        inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding")

    position_embeds = self.wpe(inputs["position_ids"])

    if inputs["token_type_ids"] is not None:
        inputs["token_type_ids"] = tf.reshape(
            inputs["token_type_ids"], [-1, get_shape(inputs["token_type_ids"])[-1]]
        )
        token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding")
    else:
        token_type_embeds = 0

    position_embeds = tf.cast(position_embeds, dtype=inputs["inputs_embeds"].dtype)
    token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype)
    hidden_states = inputs["inputs_embeds"] + position_embeds + token_type_embeds
    hidden_states = self.drop(hidden_states, training=inputs["training"])

    output_shape = input_shape + [get_shape(hidden_states)[-1]]

    presents = () if inputs["use_cache"] else None
    all_attentions = () if inputs["output_attentions"] else None
    all_hidden_states = () if inputs["output_hidden_states"] else None
    for i, (block, layer_past) in enumerate(zip(self.h, inputs["past"])):
        if inputs["output_hidden_states"]:
            all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)

        outputs = block(
            hidden_states,
            layer_past,
            inputs["attention_mask"],
            inputs["head_mask"][i],
            inputs["use_cache"],
            inputs["output_attentions"],
            training=inputs["training"],
        )

        hidden_states, present = outputs[:2]
        if inputs["use_cache"]:
            presents = presents + (present,)

        if inputs["output_attentions"]:
            all_attentions = all_attentions + (outputs[2],)

    hidden_states = self.ln_f(hidden_states)
    hidden_states = tf.reshape(hidden_states, output_shape)
    pooled_output = self.pooler(hidden_states)

    # Add last hidden state
    if inputs["output_hidden_states"]:
        all_hidden_states = all_hidden_states + (hidden_states,)

    if inputs["output_attentions"]:
        # let the number of heads free (-1) so we can extract attention even after head pruning
        attention_output_shape = input_shape[:-1] + [-1] + get_shape(all_attentions[0])[-2:]
        all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)

    # if not inputs["return_dict"]:
    #     return tuple(v for v in [hidden_states, pooled_output, presents, all_hidden_states, all_attentions] if v is not None)
    return tuple(
        v for v in [hidden_states, pooled_output, presents, all_hidden_states, all_attentions]
        if v is not None)
    # return outputs  # hidden_states, pooled_output, presents, all_hidden_states, all_attentions
def call(
    self,
    inputs,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    training=False,
):
    if isinstance(inputs, (tuple, list)):
        input_ids = inputs[0]
        attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
        token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
        position_ids = inputs[3] if len(inputs) > 3 else position_ids
        head_mask = inputs[4] if len(inputs) > 4 else head_mask
        inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
        assert len(inputs) <= 6, "Too many inputs."
    elif isinstance(inputs, dict):
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask", attention_mask)
        token_type_ids = inputs.get("token_type_ids", token_type_ids)
        position_ids = inputs.get("position_ids", position_ids)
        head_mask = inputs.get("head_mask", head_mask)
        inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
        assert len(inputs) <= 6, "Too many inputs."
    else:
        input_ids = inputs

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = get_shape(input_ids)
    elif inputs_embeds is not None:
        input_shape = get_shape(inputs_embeds)[:-1]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if attention_mask is None:
        attention_mask = tf.fill(input_shape, 1)
    if token_type_ids is None:
        token_type_ids = tf.fill(input_shape, 0)

    extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
    head_mask = self.get_head_mask(head_mask)

    hidden_states = self.embeddings(
        [input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
    if hasattr(self, "embeddings_project"):
        hidden_states = self.embeddings_project(hidden_states, training=training)

    encoder_outputs = self.encoder(
        [hidden_states, extended_attention_mask, head_mask], training=training)

    sequence_output = encoder_outputs[0]
    pooled_output = sequence_output[:, 0]

    outputs = (
        sequence_output,
        pooled_output,
    ) + encoder_outputs[1:]

    return outputs
def call(self,
         inputs,
         token_type_ids=None,
         attention_mask=None,
         position_ids=None,
         task_type_ids=None,
         head_mask=None,
         training=False):
    if isinstance(inputs, (tuple, list)):
        input_ids = inputs[0]
        attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
        token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
        position_ids = inputs[3] if len(inputs) > 3 else position_ids
        task_type_ids = inputs[4] if len(inputs) > 4 else task_type_ids
        head_mask = inputs[5] if len(inputs) > 5 else head_mask
        assert len(inputs) <= 6, "Too many inputs."
    elif isinstance(inputs, dict):
        input_ids = inputs.get('input_ids')
        attention_mask = inputs.get('attention_mask', attention_mask)
        token_type_ids = inputs.get('token_type_ids', token_type_ids)
        position_ids = inputs.get('position_ids', position_ids)
        task_type_ids = inputs.get('task_type_ids', task_type_ids)
        head_mask = inputs.get('head_mask', head_mask)
        assert len(inputs) <= 6, "Too many inputs."
    else:
        input_ids = inputs

    if attention_mask is None:
        attention_mask = tf.fill(get_shape(input_ids), 1)
    if token_type_ids is None:
        token_type_ids = tf.fill(get_shape(input_ids), 0)

    # We create a 3D attention mask from a 2D tensor mask.
    # Sizes are [batch_size, 1, 1, to_seq_length]
    # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
    # this attention mask is more simple than the triangular masking of causal attention
    # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
    extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
    # masked positions, this operation will create a tensor which is 0.0 for
    # positions we want to attend and -10000.0 for masked positions.
    # Since we are adding it to the raw scores before the softmax, this is
    # effectively the same as removing these entirely.
    extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
    # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
    if head_mask is not None:
        raise NotImplementedError
    else:
        head_mask = [None] * self.num_hidden_layers
        # head_mask = tf.constant([0] * self.num_hidden_layers)

    embedding_output = self.embeddings(
        [input_ids, position_ids, token_type_ids, task_type_ids], training=training)
    encoder_outputs = self.encoder(
        [embedding_output, extended_attention_mask, head_mask], training=training)

    sequence_output = encoder_outputs[0]
    pooled_output = self.pooler(sequence_output)

    outputs = (
        sequence_output,
        pooled_output,
    ) + encoder_outputs[1:]  # add hidden_states and attentions if they are here

    return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
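# Hedged sketch of the additive mask built above (standalone, toy shapes): a [batch, seq] 0/1
# padding mask becomes a broadcastable [batch, 1, 1, seq] bias of 0.0 (attend) / -10000.0
# (ignore) that is simply added to the raw attention scores before softmax.
import tensorflow as tf

def extended_attention_mask_example():
    attention_mask = tf.constant([[1, 1, 1, 0, 0]])                  # one padded sequence
    extended = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], tf.float32)
    extended = (1.0 - extended) * -10000.0
    scores = tf.zeros([1, 2, 5, 5])                                  # [batch, heads, q_len, k_len]
    probs = tf.nn.softmax(scores + extended, axis=-1)
    # padded key positions 3 and 4 receive ~0 attention probability
    return probs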
def call(self, query, key, value, attention_mask=None, head_mask=None, training=False):
    batch_size = get_shape(query)[0]
    mixed_query_layer = self.query(query)
    mixed_key_layer = self.key(key)
    mixed_value_layer = self.value(value)

    query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
    key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
    value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attention_scores = tf.matmul(
        query_layer, key_layer, transpose_b=True)  # (batch_size, num_heads, seq_len_q, seq_len_k)
    attention_shape = get_shape(attention_scores)
    from_seq_length = attention_shape[2]
    to_seq_length = attention_shape[3]

    if self.use_relative_position:
        max_relative_position = 64
        relations_keys = generate_relative_positions_embeddings(
            to_seq_length,
            self.attention_head_size,
            max_relative_position,
            "relative_positions_keys",
            cache=False)
        # query_layer_t is [F, B, N, H]
        query_layer_t = tf.transpose(query_layer, [2, 0, 1, 3])
        # query_layer_r is [F, B * N, H]
        query_layer_r = tf.reshape(
            query_layer_t,
            [from_seq_length, batch_size * self.num_attention_heads, self.attention_head_size])
        # key_position_scores is [F, B * N, F|T]
        key_position_scores = tf.matmul(query_layer_r, relations_keys, transpose_b=True)
        # key_position_scores_r is [F, B, N, F|T]
        key_position_scores_r = tf.reshape(
            key_position_scores,
            [from_seq_length, batch_size, self.num_attention_heads, from_seq_length])
        # key_position_scores_r_t is [B, N, F, F|T]
        key_position_scores_r_t = tf.transpose(key_position_scores_r, [1, 2, 0, 3])
        attention_scores += key_position_scores_r_t

    dk = tf.cast(get_shape(key_layer)[-1], tf.float32)  # scale attention_scores
    attention_scores = attention_scores / tf.math.sqrt(dk)

    if attention_mask is not None:
        attention_scores = attention_scores + attention_mask

    # Normalize the attention scores to probabilities.
    attention_probs = tf.nn.softmax(attention_scores, axis=-1)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self.dropout(attention_probs, training=training)

    # Mask heads if we want to
    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    context_layer = tf.matmul(attention_probs, value_layer)
    context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
    context_layer = tf.reshape(
        context_layer, (batch_size, -1, self.all_head_size))  # (batch_size, seq_len_q, all_head_size)

    outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
    return outputs