def call(self, inputs, **kwargs):
    source, target = inputs
    target_shape = keras.backend.shape(target)
    if keras.backend.image_data_format() == 'channels_first':
        source = backend.transpose(source, (0, 2, 3, 1))
        output = backend.resize_images(source,
                                       (target_shape[2], target_shape[3]),
                                       method='nearest')
        output = backend.transpose(output, (0, 3, 1, 2))
        return output
    else:
        return backend.resize_images(source,
                                     (target_shape[1], target_shape[2]),
                                     method='nearest')
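
# A minimal sketch of the nearest-neighbour resize performed in call() above, assuming
# the custom backend.resize_images(images, size, method=...) wraps TF's image resize on
# NHWC tensors (TF 1.x API). The demo name and shapes are illustrative, not from the source.
def _demo_resize_to_target():
    import tensorflow as tf
    source = tf.ones([2, 7, 9, 16])    # NHWC feature map to be resized
    target = tf.ones([2, 14, 18, 16])  # tensor whose spatial size we want to match
    target_shape = tf.shape(target)
    size = tf.stack([target_shape[1], target_shape[2]])
    return tf.image.resize_images(
        source, size,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)  # -> [2, 14, 18, 16]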
def get_attn_mask_func(inputs):
    input_mask = inputs                                  # [batch, 512]
    batch_size = K.shape(input_mask)[0]
    target_len = K.shape(input_mask)[1]
    input_mask_trans = K.transpose(input_mask)           # [512, batch]
    # data_mask = input_mask[None]
    data_mask = K.expand_dims(input_mask_trans, axis=0)  # [1, 512, batch]
    # mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz], dtype=tf_float)
    mems_mask = keras.layers.Lambda(
        lambda x: tf.zeros([1, 0, x], dtype=tf.float32))(batch_size)  # [1, 0, batch]
    # data_mask = tf.concat([mems_mask, data_mask], 1)
    data_mask = K.concatenate([mems_mask, data_mask], axis=1)         # [1, 512, batch]
    # attn_mask = data_mask[:, :, :, None]
    attn_mask = K.expand_dims(data_mask, axis=-1)                     # [1, 512, batch, 1]
    # attn_mask = tf.cast(attn_mask > 0, dtype=tf_float)
    attn_mask = keras.layers.Lambda(
        lambda x: tf.cast(x > 0, dtype=tf.float32))(attn_mask)
    # non_tgt_mask = -tf.eye(512, dtype=tf.float32)
    non_tgt_mask = keras.layers.Lambda(
        lambda x: -tf.eye(x, dtype=tf.float32))(target_len)           # [512, 512]
    # non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=tf_float), non_tgt_mask], axis=-1)
    tmp = keras.layers.Lambda(
        lambda x: tf.zeros([x, 0], dtype=tf.float32))(target_len)
    non_tgt_mask = K.concatenate([tmp, non_tgt_mask], axis=-1)
    # non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=tf_float)
    non_tgt_mask = K.expand_dims(non_tgt_mask, axis=-1)
    non_tgt_mask = K.expand_dims(non_tgt_mask, axis=-1)               # [512, 512, 1, 1]
    tmp2 = K.greater(attn_mask + non_tgt_mask, 0)
    attn_mask = K.cast(tmp2, tf.float32)                              # [512, 512, batch, 1]
    return attn_mask
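
# An equivalent plain-TF restatement of the mask arithmetic above (memory length 0,
# so the mems_mask concat is a no-op), useful for checking the broadcast shapes.
# All names here are illustrative, not part of the source.
def _demo_attn_mask_shape():
    import tensorflow as tf
    input_mask = tf.zeros([2, 512])                                # [batch, qlen]
    data_mask = tf.transpose(input_mask)[None]                     # [1, qlen, batch]
    attn_mask = tf.cast(data_mask[:, :, :, None] > 0, tf.float32)  # [1, qlen, batch, 1]
    non_tgt_mask = -tf.eye(512)[:, :, None, None]                  # [qlen, qlen, 1, 1]
    # broadcasting the two yields the final [qlen, qlen, batch, 1] mask
    return tf.cast((attn_mask + non_tgt_mask) > 0, tf.float32)     # [512, 512, 2, 1]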
def build_xlnet_for_tf_estimator(inputs,
                                 num_token,
                                 num_layer,
                                 num_head,
                                 embedding_dim,
                                 attention_head_dim,
                                 feed_forward_dim,
                                 target_len,
                                 is_training,
                                 memory_len=None,
                                 dropout=0.0,
                                 attention_dropout=0.0,
                                 attention_type=None,
                                 shared_biases=True):
    input_ids, input_mask, segment_ids, cls_index, \
        p_mask, start_positions, end_positions, is_impossible = inputs

    attn_mask = get_attn_mask(input_mask)

    input_ids_trans = keras.layers.Lambda(lambda x: K.transpose(x))(input_ids)
    token_embed = keras.layers.Embedding(input_dim=num_token,
                                         output_dim=embedding_dim,
                                         name='Embed-Token')(input_ids_trans)
    token_embed_dropout = keras.layers.Dropout(
        rate=dropout, name='Embed-Token-Dropout')(token_embed,
                                                  training=is_training)
    pos_emb = get_pos_emb([input_ids_trans, token_embed])
    pos_emb = keras.layers.Dropout(rate=dropout)(pos_emb, training=is_training)

    initializer = keras.initializers.RandomNormal(stddev=0.02)

    segment_ids_trans = keras.layers.Lambda(lambda x: K.transpose(x))(segment_ids)
    segment_mat, segment_embed = RelativeSegmentEmbedding(
        num_layer=num_layer,
        num_head=num_head,
        attention_dim=attention_head_dim,
        initializer=initializer,
        name='Embed-Segment',
    )(segment_ids_trans)
    r_w_bias, r_r_bias, r_s_bias = RelativeBias(
        num_layer=num_layer,
        num_head=num_head,
        attention_head_dim=attention_head_dim,
        bias_initializer=initializer,
        name='Relative-Bias',
    )(input_ids_trans)

    content_output = token_embed_dropout

    if FLAGS.short_cut_fake:
        # replace the real side inputs with constants of the same shapes
        attn_mask = tf.constant(1.0, shape=[512, 512, 1, 1], dtype=np.float32)
        segment_mat = tf.constant(1.0, shape=[512, 512, 1, 2], dtype=np.float32)
        pos_emb = tf.constant(1.0, shape=[1024, 1, 1024], dtype=np.float32)

    if FLAGS.short_cut_fuse:
        # flatten all side inputs into one buffer so they cross stages as a single tensor
        attn_mask_flat = tf.reshape(attn_mask, [-1])
        segment_mat_flat = tf.reshape(segment_mat, [-1])
        segment_embed_flat = tf.reshape(segment_embed, [-1])
        pos_emb_flat = tf.reshape(pos_emb, [-1])
        r_w_bias_flat = tf.reshape(r_w_bias, [-1])
        r_r_bias_flat = tf.reshape(r_r_bias, [-1])
        r_s_bias_flat = tf.reshape(r_s_bias, [-1])
        fused = tf.concat([attn_mask_flat, segment_mat_flat, segment_embed_flat,
                           pos_emb_flat, r_w_bias_flat, r_r_bias_flat,
                           r_s_bias_flat], 0)

    for i in range(num_layer):
        attention = RelativeMultiHeadAttention(
            num_head=num_head,
            attention_head_dim=attention_head_dim,
            embedding_dim=embedding_dim,
            dropout=dropout,
            dropatt=attention_dropout,
            is_training=is_training,
            initializer=initializer,
            name='Attention-{}'.format(i + 1),
        )
        attention_add = tf.keras.layers.Add(
            name='Attention-Residual-{}'.format(i + 1))
        attention_layer_norm = LayerNormalization(
            name='Attention-Normal-{}'.format(i + 1))
        feed_forward = FeedForward(feed_forward_dim=feed_forward_dim,
                                   embedding_dim=embedding_dim,
                                   dropout_rate=dropout,
                                   kernel_initializer=initializer,
                                   activation=gelu,
                                   name='FeedForward-{}'.format(i + 1))
        feed_forward_add = tf.keras.layers.Add(
            name='FeedForward-Residual-{}'.format(i + 1))
        feed_forward_layer_norm = LayerNormalization(
            name='FeedForward-Normal-{}'.format(i + 1))

        # bind i by value so each Lambda slices its own layer's parameters
        segment_embed_i = keras.layers.Lambda(lambda x, i=i: x[i])(segment_embed)
        r_w_bias_i = keras.layers.Lambda(lambda x, i=i: x[i])(r_w_bias)
        r_r_bias_i = keras.layers.Lambda(lambda x, i=i: x[i])(r_r_bias)
        r_s_bias_i = keras.layers.Lambda(lambda x, i=i: x[i])(r_s_bias)

        if FLAGS.short_cut_fuse:
            # split sizes hard-code XLNet-Large dims: seq 512, hidden 1024,
            # 24 layers, 16 heads x head dim 64
            attn_mask_flat, segment_mat_flat, segment_embed_flat, \
                pos_emb_flat, r_w_bias_flat, r_r_bias_flat, r_s_bias_flat = \
                tf.split(fused, [512 * 512 * 1, 512 * 512 * 2, 24 * 2 * 1024,
                                 1024 * 1024, 24 * 1024, 24 * 1024, 24 * 1024], 0)
            attn_mask = tf.reshape(attn_mask_flat, [512, 512, 1, 1])
            segment_mat = tf.reshape(segment_mat_flat, [512, 512, 1, 2])
            segment_embed = tf.reshape(segment_embed_flat, [24, 2, 16, 64])
            pos_emb = tf.reshape(pos_emb_flat, [1024, 1, 1024])
            r_w_bias = tf.reshape(r_w_bias_flat, [24, 16, 64])
            r_r_bias = tf.reshape(r_r_bias_flat, [24, 16, 64])
            r_s_bias = tf.reshape(r_s_bias_flat, [24, 16, 64])
            print(attn_mask, segment_mat, segment_embed, pos_emb, r_w_bias,
                  r_r_bias, r_s_bias)

        def _build_output(query):
            attention_input = query
            _output = attention([
                query, pos_emb, segment_embed_i, segment_mat, r_w_bias_i,
                r_r_bias_i, r_s_bias_i, attn_mask
            ])
            _output = attention_add([attention_input, _output])
            _output = attention_layer_norm(_output)
            feed_forward_input = keras.layers.Lambda(
                lambda x: K.identity(x))(_output)
            _output = feed_forward(_output, training=is_training)
            _output = feed_forward_add([feed_forward_input, _output])
            _output = feed_forward_layer_norm(_output)
            return _output

        content_output = _build_output(content_output)

    output = keras.layers.Dropout(rate=dropout)(content_output,
                                                training=is_training)
    xlnet_loss = XLnetLoss(d_model=embedding_dim,
                           seq_len=target_len,
                           kernel_initializer=initializer,
                           name="XLNET_LOSS")([
                               cls_index, start_positions, end_positions,
                               is_impossible, p_mask, output
                           ])
    return xlnet_loss
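
# The hard-coded split sizes above encode a fixed flat layout for XLNet-Large-style
# dimensions (seq_len 512, hidden 1024, 16 heads of size 64). A small helper like the
# one below makes the arithmetic explicit; the helper name is illustrative, not from
# the source. _fused_split_sizes(24) reproduces the literal list passed to tf.split.
def _fused_split_sizes(num_layer, seq_len=512, hidden=1024, num_head=16, head_dim=64):
    return [
        seq_len * seq_len * 1,                 # attn_mask     [seq, seq, 1, 1]
        seq_len * seq_len * 2,                 # segment_mat   [seq, seq, 1, 2]
        num_layer * 2 * num_head * head_dim,   # segment_embed [layers, 2, heads, head_dim]
        2 * seq_len * hidden,                  # pos_emb       [2 * seq, 1, hidden]
        num_layer * num_head * head_dim,       # r_w_bias      [layers, heads, head_dim]
        num_layer * num_head * head_dim,       # r_r_bias
        num_layer * num_head * head_dim,       # r_s_bias
    ]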
def __init__(self,
             num_token,
             num_layer,
             num_head,
             embedding_dim,
             attention_head_dim,
             feed_forward_dim,
             target_len,
             is_training,
             memory_len=None,
             dropout=0.0,
             attention_dropout=0.0,
             attention_type=None,
             shared_biases=True):
    self.num_layer = num_layer
    self.dropout = dropout
    self.attention_dropout = attention_dropout
    self.token_embed = keras.layers.Embedding(input_dim=num_token,
                                              output_dim=embedding_dim,
                                              name='Embed-Token')
    initializer = keras.initializers.RandomNormal(stddev=0.02)
    self.segment_ids_trans = keras.layers.Lambda(lambda x: K.transpose(x))
    self.segment_mat_embed = RelativeSegmentEmbedding(
        num_layer=num_layer,
        num_head=num_head,
        attention_dim=attention_head_dim,
        initializer=initializer,
        name='Embed-Segment')
    self.relative_bias = RelativeBias(
        num_layer=num_layer,
        num_head=num_head,
        attention_head_dim=attention_head_dim,
        bias_initializer=initializer,
        name='Relative-Bias')
    self.attention = []
    self.attention_add = []
    self.attention_layer_norm = []
    self.feed_forward = []
    self.feed_forward_add = []
    self.feed_forward_layer_norm = []
    for i in range(num_layer):
        self.attention.append(
            RelativeMultiHeadAttention(
                num_head=num_head,
                attention_head_dim=attention_head_dim,
                embedding_dim=embedding_dim,
                dropout=dropout,
                dropatt=attention_dropout,
                is_training=is_training,
                initializer=initializer,
                name='Attention-{}'.format(i + 1),
            ))
        self.attention_add.append(
            tf.keras.layers.Add(name='Attention-Residual-{}'.format(i + 1)))
        self.attention_layer_norm.append(
            LayerNormalization(name='Attention-Normal-{}'.format(i + 1)))
        self.feed_forward.append(
            FeedForward(feed_forward_dim=feed_forward_dim,
                        embedding_dim=embedding_dim,
                        dropout_rate=dropout,
                        kernel_initializer=initializer,
                        activation=gelu,
                        name='FeedForward-{}'.format(i + 1)))
        self.feed_forward_add.append(
            tf.keras.layers.Add(name='FeedForward-Residual-{}'.format(i + 1)))
        self.feed_forward_layer_norm.append(
            LayerNormalization(name='FeedForward-Normal-{}'.format(i + 1)))
    self.xlnet_loss = XLnetLoss(d_model=embedding_dim,
                                seq_len=target_len,
                                kernel_initializer=initializer,
                                name="XLNET_LOSS")
def build(self, inputs, dep_outputs=None, is_training=True):
    # pipeline device list
    devices = cluster_utils.get_pipeline_devices(FLAGS.pipeline_device_num)
    ndev = len(devices)
    # stages: embedding + dropout + num_layer transformer blocks + final dropout
    nstage = self.num_layer + 3

    def calc_device(i):
        # original proportional stage assignment, superseded by the hard split below
        idx = int((i + 2) / ((nstage + 1) / ndev + 1))
        split_layer_id = 13 if FLAGS.num_layer == 24 else 19
        # For XLNet-24: stage 0 ends at FeedForward-11 (split_layer_id 13)
        # For XLNet-36: stage 0 ends at FeedForward-17 (split_layer_id 19)
        if i < split_layer_id:
            return 0
        else:
            return 1
        return idx  # unreachable; kept from the original heuristic

    dep = None
    device_idx = 0
    if dep_outputs is not None and dep_outputs[device_idx] is not None:
        dep = dep_outputs[device_idx] \
            if isinstance(dep_outputs[device_idx], list) else [dep_outputs[device_idx]]
    with tf.control_dependencies(dep), tf.device(devices[device_idx]):
        input_ids, input_mask, segment_ids, cls_index, \
            p_mask, start_positions, end_positions, is_impossible = inputs
        # 1MB, [512, 512, 1, 1]
        attn_mask = get_attn_mask(input_mask)
        input_ids_trans = keras.layers.Lambda(lambda x: K.transpose(x))(input_ids)
        token_embed = self.token_embed(input_ids_trans)
        segment_ids_trans = self.segment_ids_trans(segment_ids)
        # 2MB [512, 512, 1, 2]; 192KB [24, 2, 16, 64]
        segment_mat, segment_embed = self.segment_mat_embed(segment_ids_trans)
        # 3 x 96KB [24, 16, 64]
        r_w_bias, r_r_bias, r_s_bias = self.relative_bias(input_ids_trans)
        token_embed_dropout = keras.layers.Dropout(
            rate=self.dropout,
            name='Embed-Token-Dropout')(token_embed, training=is_training)
        content_output = token_embed_dropout
        pos_emb = get_pos_emb([input_ids_trans, token_embed])
        # 4MB [1024, 1, 1024]
        pos_emb = keras.layers.Dropout(rate=self.dropout)(pos_emb,
                                                          training=is_training)

        if FLAGS.short_cut_fake:
            attn_mask = tf.constant(1.0, shape=[512, 512, 1, 1], dtype=np.float32)
            segment_mat = tf.constant(1.0, shape=[512, 512, 1, 2], dtype=np.float32)
            pos_emb = tf.constant(1.0, shape=[1024, 1, 1024], dtype=np.float32)

        if FLAGS.short_cut_fuse:
            # flatten all side inputs into one buffer so they cross stages as a single tensor
            attn_mask_flat = tf.reshape(attn_mask, [-1])
            segment_mat_flat = tf.reshape(segment_mat, [-1])
            segment_embed_flat = tf.reshape(segment_embed, [-1])
            pos_emb_flat = tf.reshape(pos_emb, [-1])
            r_w_bias_flat = tf.reshape(r_w_bias, [-1])
            r_r_bias_flat = tf.reshape(r_r_bias, [-1])
            r_s_bias_flat = tf.reshape(r_s_bias, [-1])
            fused = tf.concat([attn_mask_flat, segment_mat_flat, segment_embed_flat,
                               pos_emb_flat, r_w_bias_flat, r_r_bias_flat,
                               r_s_bias_flat], 0)

    def _build_output(query, i, pos_emb, segment_mat, segment_embed,
                      attn_mask, r_w_bias, r_r_bias, r_s_bias):
        segment_embed_i = keras.layers.Lambda(lambda x: x[i])(segment_embed)
        r_w_bias_i = keras.layers.Lambda(lambda x: x[i])(r_w_bias)
        r_r_bias_i = keras.layers.Lambda(lambda x: x[i])(r_r_bias)
        r_s_bias_i = keras.layers.Lambda(lambda x: x[i])(r_s_bias)
        attention_input = query
        _output = self.attention[i]([
            query, pos_emb, segment_embed_i, segment_mat, r_w_bias_i,
            r_r_bias_i, r_s_bias_i, attn_mask
        ])
        _output = self.attention_add[i]([attention_input, _output])
        _output = self.attention_layer_norm[i](_output)
        feed_forward_input = keras.layers.Lambda(lambda x: K.identity(x))(_output)
        _output = self.feed_forward[i](_output, training=is_training)
        _output = self.feed_forward_add[i]([feed_forward_input, _output])
        _output = self.feed_forward_layer_norm[i](_output)
        return _output

    # outputs of all stages
    stage_outputs = []
    # previous stage's output; for the first stage it is input_ids
    prev_output = input_ids
    # previous device index, initially 0
    prev_device_idx = 0
    for i in range(self.num_layer):
        layer = i + 2
        device_idx = calc_device(layer)
        dep = None
        boundary = False
        if device_idx != prev_device_idx:
            # current layer crosses a stage boundary
            boundary = True
            if dep_outputs is not None and dep_outputs[device_idx] is not None:
                dep = dep_outputs[device_idx] \
                    if isinstance(dep_outputs[device_idx], list) else [dep_outputs[device_idx]]
            stage_outputs.append(prev_output)
            prev_device_idx = device_idx
        with tf.control_dependencies(dep), tf.device(devices[device_idx]):
            if boundary:
                if FLAGS.short_cut_fake:
                    attn_mask = tf.constant(1.0, shape=[512, 512, 1, 1],
                                            dtype=np.float32)
                    segment_mat = tf.constant(1.0, shape=[512, 512, 1, 2],
                                              dtype=np.float32)
                    pos_emb = tf.constant(1.0, shape=[1024, 1, 1024],
                                          dtype=np.float32)
                if FLAGS.short_cut_fuse:
                    # re-materialize the side inputs from the fused buffer on the new stage
                    num_layers = FLAGS.num_layer
                    attn_mask_flat, segment_mat_flat, segment_embed_flat, \
                        pos_emb_flat, r_w_bias_flat, r_r_bias_flat, r_s_bias_flat = \
                        tf.split(fused, [512 * 512 * 1, 512 * 512 * 2,
                                         num_layers * 2 * 1024, 1024 * 1024,
                                         num_layers * 1024, num_layers * 1024,
                                         num_layers * 1024], 0)
                    attn_mask = tf.reshape(attn_mask_flat, [512, 512, 1, 1])
                    segment_mat = tf.reshape(segment_mat_flat, [512, 512, 1, 2])
                    segment_embed = tf.reshape(segment_embed_flat,
                                               [num_layers, 2, 16, 64])
                    pos_emb = tf.reshape(pos_emb_flat, [1024, 1, 1024])
                    r_w_bias = tf.reshape(r_w_bias_flat, [num_layers, 16, 64])
                    r_r_bias = tf.reshape(r_r_bias_flat, [num_layers, 16, 64])
                    r_s_bias = tf.reshape(r_s_bias_flat, [num_layers, 16, 64])
                    print(attn_mask, segment_mat, segment_embed, pos_emb,
                          r_w_bias, r_r_bias, r_s_bias)
            content_output = _build_output(content_output, i, pos_emb, segment_mat,
                                           segment_embed, attn_mask, r_w_bias,
                                           r_r_bias, r_s_bias)
            prev_output = content_output

    # the final dropout + loss may also cross a stage boundary
    layer = self.num_layer + 2
    device_idx = calc_device(layer)
    dep = None
    if device_idx != prev_device_idx:
        if dep_outputs is not None and dep_outputs[device_idx] is not None:
            dep = dep_outputs[device_idx] \
                if isinstance(dep_outputs[device_idx], list) else [dep_outputs[device_idx]]
        stage_outputs.append(prev_output)
    with tf.control_dependencies(dep), tf.device(devices[device_idx]):
        output = keras.layers.Dropout(rate=self.dropout)(content_output,
                                                         training=is_training)
        xlnet_loss = self.xlnet_loss([
            cls_index, start_positions, end_positions, is_impossible, p_mask,
            output
        ])
        stage_outputs.append(xlnet_loss)
    return xlnet_loss, stage_outputs
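
# Illustration of the two-stage split implemented by calc_device above, assuming
# FLAGS.num_layer == 24 and two pipeline devices: transformer block i is placed by
# calc_device(i + 2), so blocks 0..10 (through FeedForward-11) land on device 0 and
# blocks 11..23 on device 1. The function name is illustrative, not from the source.
def _demo_stage_assignment(num_layer=24, split_layer_id=13):
    placement = {}
    for i in range(num_layer):
        placement[i] = 0 if (i + 2) < split_layer_id else 1
    return placement  # {0: 0, ..., 10: 0, 11: 1, ..., 23: 1}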
def call(self, inputs, **kwargs):
    cls_index, start_positions, end_positions, is_impossible, p_mask, output = inputs
    # output: [512, ?, 1024] (seq_len, batch, hidden)
    if len(start_positions.shape) == 1:
        start_positions = K.expand_dims(start_positions, axis=-1)
        cls_index = K.expand_dims(cls_index, axis=-1)
        end_positions = K.expand_dims(end_positions, axis=-1)
        is_impossible = K.expand_dims(is_impossible, axis=-1)
    # normalize the label tensors to shape [batch]
    cls_index = K.squeeze(cls_index, -1)
    start_positions = K.squeeze(start_positions, -1)
    end_positions = K.squeeze(end_positions, -1)
    is_impossible = K.squeeze(is_impossible, -1)

    # logit of the start position
    start_logits = self.dense(output)
    start_logits = K.transpose(K.squeeze(start_logits, -1))  # [?, 512]
    start_logits_masked = start_logits * (1 - p_mask) - 1e30 * p_mask
    start_log_probs = keras.layers.Lambda(
        lambda x: tf.nn.log_softmax(x, -1))(start_logits_masked)

    # logit of the end position
    start_positions = K.cast(start_positions, dtype=tf.int32)
    # start_index_1 = K.one_hot(start_positions, self.max_seq_length)
    start_index = keras.layers.Lambda(
        lambda x: tf.one_hot(x[0], x[1], dtype=tf.float32))(
            [start_positions, self.max_seq_length])
    # start_features = tf.einsum("lbh,bl->bh", output, start_index)
    start_features = keras.layers.Lambda(
        lambda x: tf.einsum("lbh,bl->bh", x[0], x[1]))([output, start_index])
    start_features = K.expand_dims(start_features, 0)
    start_features = K.tile(start_features, [self.max_seq_length, 1, 1])
    tmp_concat = K.concatenate([output, start_features], axis=-1)
    end_logits = self.dense_0(tmp_concat)
    # end_logits = tf.contrib.layers.layer_norm(end_logits, begin_norm_axis=-1)
    end_logits = self.layer_norm(end_logits)
    end_logits = self.dense_1(end_logits)
    end_logits = K.transpose(K.squeeze(end_logits, -1))
    end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask
    # end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
    end_log_probs = keras.layers.Lambda(
        lambda x: tf.nn.log_softmax(x, -1))(end_logits_masked)

    start_loss = -K.sum(start_log_probs * start_index, axis=-1)
    start_loss = K.mean(start_loss)
    end_positions = K.cast(end_positions, dtype=tf.int32)
    # end_index = K.one_hot(end_positions, self.max_seq_length)
    end_index = keras.layers.Lambda(
        lambda x: tf.one_hot(x[0], x[1], dtype=tf.float32))(
            [end_positions, self.max_seq_length])
    end_loss = -K.sum(end_log_probs * end_index, axis=-1)
    end_loss = K.mean(end_loss)
    total_loss = (start_loss + end_loss) * 0.5

    # an additional layer to predict answerability
    cls_index = K.cast(cls_index, dtype=tf.int32)
    # cls_index = K.one_hot(cls_index, self.max_seq_length)
    cls_index = keras.layers.Lambda(
        lambda x: tf.one_hot(x[0], x[1], dtype=tf.float32))(
            [cls_index, self.max_seq_length])
    # cls_feature = tf.einsum("lbh,bl->bh", output, cls_index)
    cls_feature = keras.layers.Lambda(
        lambda x: tf.einsum("lbh,bl->bh", x[0], x[1]))([output, cls_index])
    # start_p = tf.nn.softmax(start_logits_masked, axis=-1, name="softmax_start")
    start_p = keras.layers.Lambda(
        lambda x: tf.nn.softmax(x, axis=-1))(start_logits_masked)
    # start_feature = tf.einsum("lbh,bl->bh", output, start_p)
    start_feature = keras.layers.Lambda(
        lambda x: tf.einsum("lbh,bl->bh", x[0], x[1]))([output, start_p])
    # ans_feature = tf.concat([start_feature, cls_feature], -1)
    ans_feature = K.concatenate([start_feature, cls_feature], -1)
    ans_feature = self.dense_0_1(ans_feature)
    ans_feature = keras.layers.Dropout(rate=0.1)(ans_feature, training=True)
    cls_logits = self.dense_1_1(ans_feature)
    cls_logits = K.squeeze(cls_logits, -1)
    is_impossible = K.reshape(is_impossible, [-1])
    regression_loss = keras.layers.Lambda(
        lambda x: tf.nn.sigmoid_cross_entropy_with_logits(labels=x[0],
                                                          logits=x[1]))(
            [is_impossible, cls_logits])
    regression_loss = K.mean(regression_loss)
    total_loss += regression_loss * 0.5
    self.add_loss(total_loss, inputs=True)
    return total_loss
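
# A compact numpy restatement of the span loss computed above, assuming log-softmaxed
# start/end scores and one-hot targets; the full loss adds 0.5 * the sigmoid
# cross-entropy answerability term. Names and the toy numbers are illustrative.
def _demo_span_loss():
    import numpy as np
    start_log_probs = np.log([[0.7, 0.2, 0.1]])  # [batch, seq]
    end_log_probs = np.log([[0.1, 0.8, 0.1]])
    start_index = np.array([[1.0, 0.0, 0.0]])    # one-hot start position
    end_index = np.array([[0.0, 1.0, 0.0]])      # one-hot end position
    start_loss = -(start_log_probs * start_index).sum(-1).mean()
    end_loss = -(end_log_probs * end_index).sum(-1).mean()
    return 0.5 * (start_loss + end_loss)         # ~0.29 for these toy numbers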