def logits_for_sketch_prediction(decoder_states, cls_state, num_channel, hn=None, act_name="relu", wd=0., keep_prob=1.0, is_train=None, compress_mask=None, scope=None):
    """Per-decoder-step classification logits over `num_channel` classes.

    Each decoder state is combined additively with a projection of a
    sentence-level `cls_state` (tiled across time), passed through an
    activation, and projected to `num_channel` logits.

    :param decoder_states: [bs, dsl, d] decoder outputs.
    :param cls_state: [bs, d] sentence-level state shared across all steps.
    :param num_channel: number of output classes.
    :param hn: hidden size of the intermediate maps; defaults to the last
        dim of `decoder_states`.
    :param compress_mask: optional [bs, dsl] bool mask; when given, the
        masked-out steps are compressed away before the dense layers and the
        logits are scattered back afterwards (saves compute when targets
        are sparse along the time axis).
    :return: [bs, dsl, num_channel] logits.
    """
    compressing = not isinstance(compress_mask, type(None))  # compress_mask is not None
    hn = hn or get_shape_list(decoder_states)[-1]
    # NOTE(review): the scope default reads "logits_for_sketch_index" — likely a
    # copy-paste from the sibling function `logits_for_sketch_index`. Renaming it
    # would change variable names (and thus checkpoint keys), so it is kept as-is.
    with tf.variable_scope(scope or "logits_for_sketch_index"):
        if compressing:
            new_decoder_states, _, rev_d = compress_seq_wrt_mask(
                decoder_states, compress_mask)
        else:
            new_decoder_states = decoder_states
            rev_d = None
        map_part1 = bn_dense_layer_v2(
            new_decoder_states, hn, True, 0., "map_part1", "linear",
            False, wd, keep_prob, is_train)
        map_part2_pre = bn_dense_layer_v2(
            cls_state, hn, False, 0., "map_part2_pre", "linear",
            False, wd, keep_prob, is_train)
        # Tile the cls projection over the (possibly compressed) time dimension.
        map_part2 = tf.tile(
            tf.expand_dims(map_part2_pre, axis=1),
            [1, get_shape_list(map_part1)[1], 1])
        map_res = act_name2fn(act_name)(map_part1 + map_part2)
        logits = bn_dense_layer_v2(
            map_res, num_channel, True, 0., "logits", "linear",
            False, wd, keep_prob, is_train)
        if compressing:
            # Scatter logits back to the original [bs, dsl, num_channel] layout.
            logits = decompress_seq_wrt_mask(logits, rev_d)
        return logits
def get_word_level_split(params, input_pos_ids, wordpiece_idx, input_mask, sll, pl):  # bs,sl,pl
    """Regroup a wordpiece-level sequence into a word-level grid.

    Scatters `params` ([bs, sl] + extra dims, wordpiece-level) into a
    [bs, sll, pl] (+ extra dims) tensor where axis 1 indexes words and
    axis 2 indexes the pieces inside each word.

    :param params: [bs, sl, ...] values to regroup; bool inputs are routed
        through int32 because tf.scatter_nd does not support tf.bool updates.
    :param input_pos_ids: [bs, sl] word index of each wordpiece.
    :param wordpiece_idx: [bs, sl] within-word piece index of each wordpiece.
    :param input_mask: [bs, sl] bool validity mask.
    :param sll: word-level sequence length.
    :param pl: max number of pieces per word.
    :return: [bs, sll, pl, ...] tensor, same dtype as `params`.
    """
    bs, sl = get_shape_list(input_pos_ids)
    higher_dim = len(get_shape_list(params)) > 2
    extra_dims = get_shape_list(params)[2:] if higher_dim else []
    # tf.tile(tf.expand_dims(tf.expand_dims(tf.range(bs), 1), 2), [1, sll, pl])
    bs_idxs = tf.tile(tf.expand_dims(tf.range(bs), 1), [1, sl])
    data_coord = tf.stack([bs_idxs, input_pos_ids, wordpiece_idx], -1)  # [bs, sl, 3]
    # Masked positions are redirected to the sacrificial out-of-range slot
    # (bs, sll, pl): zero their coords, then add the max index per axis.
    mask_reversed_int = tf.cast(tf.logical_not(input_mask), tf.int32)
    data_coord = mask_v3(data_coord, input_mask, high_dim=True) + tf.stack(
        [
            mask_reversed_int * bs,
            mask_reversed_int * sll,
            mask_reversed_int * pl,
        ], axis=-1)
    # params's dtype check: scatter_nd cannot take bool updates.
    is_bool = (params.dtype == tf.bool)
    outputs = tf.scatter_nd(
        indices=data_coord,  # [bs, sl, 3]
        updates=params if not is_bool else tf.cast(params, tf.int32),  # [bs,sl]
        # one extra slot per axis absorbs the redirected (masked) updates
        shape=[bs + 1, sll + 1, pl + 1] + extra_dims)
    if is_bool:
        outputs = tf.cast(outputs, tf.bool)
    # Drop the sacrificial slots.
    outputs = outputs[:-1, :-1, :-1]
    return outputs
def cross_attn_mask_generation(from_mask, to_mask, mutual=True, head_num=None, name=None):
    """Build a [bs, slf, slt] cross-attention mask from two 1-D token masks.

    :param from_mask: 2-D bool Tensor, [bs, slf]
    :param to_mask: 2-D bool Tensor, [bs, slt]
    :param mutual: if True, a pair is valid only when both tokens are valid;
        otherwise only the `to` side is checked (mask broadcast over `from`).
    :param head_num: when an int, a head axis is inserted and tiled,
        giving [bs, head_num, slf, slt].
    :param name: optional name scope.
    :return: 3-D (or 4-D with heads) bool Tensor.
    """
    with tf.name_scope(name or 'attention_mask_generation'):
        bs, slf = get_shape_list(from_mask, 2)[:2]
        slt = get_shape_list(to_mask, 2)[1]
        if mutual:
            # outer product of the two masks via int multiplication
            from_int = tf.expand_dims(tf.cast(from_mask, tf.int32), 2)
            to_int = tf.expand_dims(tf.cast(to_mask, tf.int32), 1)
            res_mask = tf.cast(from_int * to_int, tf.bool)  # [bs,slf,slt]
        else:
            # [bs,slt] -> [bs,slf,slt]: repeat the target mask per query position
            res_mask = tf.tile(tf.expand_dims(to_mask, 1), [1, slf, 1])
        if isinstance(head_num, int):
            res_mask = tf.expand_dims(res_mask, 1)
            multiples = [1] * len(get_shape_list(res_mask))
            multiples[1] = head_num
            res_mask = tf.tile(res_mask, multiples)
        return res_mask
def _build_network_all_sketch_logits(
        self, decoder_states, encoder_states_for_decoder, encoder_mask,
        cls_state_predicate, cls_state_type, use_mask=False
):
    """Build the four per-step sketch-leaf logit heads.

    :param decoder_states: [bs, dsl, hn] decoder outputs.
    :param encoder_states_for_decoder: [bs, esl, hn] encoder outputs attended
        by the pointer-style heads.
    :param encoder_mask: [bs, esl] bool mask over encoder positions.
    :param cls_state_predicate / cls_state_type: [bs, hn] states driving the
        predicate / type classification heads.
    :param use_mask: when True, each head only computes logits on steps where
        its gold label is present (training-time compute saving).
    :return: (entity, predicate, type, num) logits.
    """
    bs = get_shape_list(decoder_states)[0]
    if use_mask:
        # Per-head "this step carries this label" masks; -1 / 0 mark "absent"
        # (same conventions as the placeholder masks in __init__).
        entity_mask = tf.not_equal(self.sketch_entity, -1)
        predicate_mask = tf.not_equal(self.sketch_predicate, 0)
        type_mask = tf.not_equal(self.sketch_type, 0)
        num_mask = tf.not_equal(self.sketch_num, -1)
    else:
        entity_mask = None
        predicate_mask = None
        type_mask = None
        num_mask = None
    # bs,sl ----- shift the encoder mask left by one and zero the last slot,
    # so the [CLS]-style position is excluded from pointer targets.
    encoder_wo_cls = tf.concat([  # [bs,sl]
        encoder_mask[:, 1:],  # [bs,sl-1]
        tf.cast(tf.zeros([get_shape_list(encoder_mask)[0], 1], tf.int32), tf.bool)  # [bs, 1]
    ], -1)
    # entity head: pointer into the encoder sequence
    logits_sketch_entity_pre = logits_for_sketch_index(  # bs,dsl,esl
        decoder_states, encoder_states_for_decoder, self.cfg["hn"], 0.,
        1 - self.cfg["clf_dropout"], self.is_training,
        compress_mask=entity_mask, scope="logits_sketch_entity_pre"
    )
    logits_sketch_entity = mask_v3(
        logits_sketch_entity_pre, encoder_wo_cls, multi_head=True,
        name="logits_sketch_entity")
    # predicate head: classification over predicate labels (first 3 reserved)
    logits_sketch_predicate_pre = logits_for_sketch_prediction(
        decoder_states, cls_state_predicate, self.num_predicate_labels - 3,
        self.cfg["hn"], self.cfg["clf_act_name"], 0., 1 - self.cfg["clf_dropout"],
        self.is_training, compress_mask=predicate_mask, scope="logits_sketch_predicate"
    )
    # prepend VERY_NEGATIVE_NUMBER logits for the 3 reserved labels so the
    # output width matches num_predicate_labels without ever predicting them
    logits_sketch_predicate = tf.concat([
        tf.ones([bs, get_shape_list(logits_sketch_predicate_pre)[1], 3],
                tf.float32) * VERY_NEGATIVE_NUMBER,
        logits_sketch_predicate_pre,
    ], axis=-1)
    # type head: same scheme as the predicate head
    logits_sketch_type_pre = logits_for_sketch_prediction(
        decoder_states, cls_state_type, self.num_type_labels - 3,
        self.cfg["hn"], self.cfg["clf_act_name"], 0., 1 - self.cfg["clf_dropout"],
        self.is_training, compress_mask=type_mask, scope="logits_sketch_type"
    )
    logits_sketch_type = tf.concat([
        tf.ones([bs, get_shape_list(logits_sketch_type_pre)[1], 3],
                tf.float32) * VERY_NEGATIVE_NUMBER,
        logits_sketch_type_pre,
    ], axis=-1)
    # num head: another pointer into the encoder sequence
    logits_sketch_num_pre = logits_for_sketch_index(
        decoder_states, encoder_states_for_decoder, self.cfg["hn"], 0.,
        1 - self.cfg["clf_dropout"], self.is_training,
        compress_mask=num_mask, scope="logits_sketch_num_pre"
    )
    logits_sketch_num = mask_v3(
        logits_sketch_num_pre, encoder_wo_cls, multi_head=True,
        name="logits_sketch_num")
    return logits_sketch_entity, logits_sketch_predicate, logits_sketch_type, logits_sketch_num
def masked_sparse2dense(input_tensor, reverse_spec, name=None):
    """Inverse of masked_dense2sparse: scatter compressed rows back to dense.

    :param input_tensor: [n_valid, hn] compressed representation.
    :param reverse_spec: dict with 'org_input_mask' (the original mask) and
        'org_coords' (coordinates of the valid positions).
    :return: dense tensor shaped like the original mask with an `hn` axis.
    """
    valid_coords = reverse_spec['org_coords']
    original_mask = reverse_spec['org_input_mask']
    with tf.variable_scope(name or "masked_sparse2dense"):
        hidden = get_shape_list(input_tensor)[-1]
        target_shape = get_shape_list(original_mask) + [hidden]
        return tf.scatter_nd(valid_coords, input_tensor, target_shape)  # [xx,hn]
def get_key_indices(tensor_input, special_token_list):
    """Find, per row, the index of the first occurrence of each special token.

    :param tensor_input: [bs, sl] int ids.
    :param special_token_list: iterable of token ids to locate.
    :return: list of [bs] int32 tensors, one per special token.
        NOTE: argmax on an all-zero row yields 0, so an absent token
        silently reports index 0 (original behavior, preserved).
    """
    get_shape_list(tensor_input, 2)  # rank-2 check only
    return [
        tf.cast(
            tf.argmax(tf.cast(tf.equal(tensor_input, token_id), tf.int32), 1),
            tf.int32)
        for token_id in special_token_list
    ]
def decompress_2nd_dim_from_batch(input_tensor, reverse_spec, name=None):
    """Inverse of compress_2nd_dim_to_batch: [nbs, 1, ...] -> [bs, sd, ...].

    Scatters the flattened batch back into its original [bs, sd] layout using
    the coordinates recorded in `reverse_spec`.

    :param input_tensor: [nbs, 1, ...] compressed tensor (singleton 2nd dim).
    :param reverse_spec: dict with "org_coords" ([nbs, 2] scatter targets) and
        "org_input_mask" ([bs, sd] original validity mask).
    :return: [bs, sd, ...] tensor with invalid slots zero-filled.
    """
    with tf.name_scope(name or "decompress_2nd_dim_from_batch"):
        input_tensor_squeeze = tf.squeeze(input_tensor, 1)
        remain_shape = get_shape_list(input_tensor_squeeze)[1:]
        org_coords = reverse_spec["org_coords"]
        org_input_mask = reverse_spec["org_input_mask"]
        bs, sd = get_shape_list(org_input_mask)
        # shape elements cast to int64 to match the dtype of tf.where coords
        return tf.scatter_nd(
            org_coords, input_tensor_squeeze,
            [tf.to_int64(elem) for elem in [bs, sd] + remain_shape])
def decompress_seq_wrt_mask(tensor_input, reverse_dict):
    """Inverse of compress_seq_wrt_mask: restore the original time layout.

    :param tensor_input: [bs, tgt_len, hn] compressed sequence.
    :param reverse_dict: dict from compress_seq_wrt_mask with "coord"
        (scatter coordinates) and "src_mask" (original [bs, src_len] mask).
    :return: [bs, src_len, hn] tensor, masked to the original valid positions.
    """
    bs, _, hn = get_shape_list(tensor_input)
    src_len = get_shape_list(reverse_dict["src_mask"])[1]
    # One extra time slot absorbs padding coordinates; sliced off just below.
    scattered = tf.scatter_nd(
        reverse_dict["coord"], tensor_input, [bs, src_len + 1, hn])
    restored = scattered[:, :-1]  # bs,src_len,hn
    return mask_v3(restored, reverse_dict["src_mask"], high_dim=True)
def smoothed_softmax_cross_entropy_with_logits(**kwargs):
    """Softmax cross-entropy with optional label smoothing.

    Keyword args: `logits`, `labels` (required), `label_smoothing`
    (default 0.0 -> plain sparse cross-entropy), `normalize` (subtract the
    entropy of the smoothed target so a perfect prediction scores ~0).

    :return: per-position cross-entropy losses (rank of `labels`).
    :raises ValueError: if logits or labels is missing.
    """
    logits = kwargs.get("logits")
    labels = kwargs.get("labels")
    smoothing = kwargs.get("label_smoothing") or 0.0
    normalize = kwargs.get("normalize")
    if logits is None or labels is None:
        raise ValueError("Both logits and labels must be provided")
    if not smoothing:
        # no smoothing requested: plain sparse cross-entropy
        return tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels)
    # works for any rank: only the last (vocab) axis matters
    vocab_size = get_shape_list(logits)[-1]
    off_mass = tf.to_float(vocab_size - 1)
    on_value = 1.0 - smoothing
    off_value = smoothing / off_mass
    soft_targets = tf.one_hot(
        tf.cast(labels, tf.int32), depth=vocab_size,
        on_value=on_value, off_value=off_value)
    xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits, labels=soft_targets)
    if not normalize:
        return xentropy
    # entropy of the smoothed distribution; epsilon guards log(0)
    normalizing = -(on_value * tf.log(on_value)
                    + off_mass * off_value * tf.log(off_value + 1e-20))
    return xentropy - normalizing
def get_slice(tensor_input, start_idxs, end_idxs):
    """Extract, per batch row, the open interval (start_idxs, end_idxs) along axis 1.

    1. the size of 1st dim of tensor_input, start_idxs, end_idxs must be equal
    2. the idxs index the 2nd dim of tensor_input; both endpoints are EXCLUSIVE
       (slice length is end - start - 1)
    3. output: a left-aligned, zero-padded tensor plus a validity mask

    :param tensor_input: [bs, sl, ...].
    :param start_idxs: [bs] int32, exclusive left endpoints.
    :param end_idxs: [bs] int32, exclusive right endpoints.
    :return: ([bs, max_len, ...] tensor, [bs, max_len] bool mask) where
        max_len is the longest extracted slice in the batch.
    """
    tensor_shape = get_shape_list(tensor_input)
    bs = tensor_shape[0]
    sl = tensor_shape[1]
    extra_dims = tensor_shape[2:] if len(tensor_shape) > 2 else []
    lens = end_idxs - start_idxs - 1  # exclusive endpoints on both sides
    max_len = tf.reduce_max(lens)
    # target bool indicator: True at positions strictly inside (start, end)
    indices_input = tf.tile(
        tf.expand_dims(tf.range(sl, dtype=tf.int32), 0), [bs, 1])  # bs, sl
    indices_new = indices_input - tf.expand_dims(start_idxs, 1) - 1  # bs, sl (position within the slice)
    tgt_bool_indicator = tf.logical_and(
        tf.greater(indices_input, tf.expand_dims(start_idxs, 1)),
        tf.less(indices_input, tf.expand_dims(end_idxs, 1)),
    )
    coord_in_input = tf.where(tgt_bool_indicator)  # [n_true, 2]
    # map each selected input coordinate to its (row, offset) in the output
    two_d_indices_new = tf.stack(  # bs,sl,2
        values=[
            tf.tile(tf.expand_dims(tf.range(bs, dtype=tf.int32), 1), [1, sl]),
            indices_new,
        ], axis=-1
    )
    coord_in_output = tf.gather_nd(two_d_indices_new, coord_in_input)  # [n_true, 2]
    gathered_tensor_input = tf.gather_nd(tensor_input, coord_in_input)  # [n_true]+extra_dims
    tensor_output = tf.scatter_nd(
        coord_in_output, gathered_tensor_input, [bs, max_len] + extra_dims)
    mask_output = generate_mask_based_on_lens(lens, max_len)
    return tensor_output, mask_output
def top_k_to_coordinate(top_k_vec, prob_tensor=None, logits=None, dim=None, name=None):
    """Select, per row, the coordinates of the `top_k_vec[i]` highest scores.

    :param top_k_vec: [bs] int32, per-row k.
    :param prob_tensor: optional [bs, sll] scores; when None it is computed
        as softmax(logits, -1)[..., dim] (logits and dim must then be given).
    :param logits: logits to derive scores from when prob_tensor is None.
    :param dim: channel of the softmax to use as the score.
    :return: (coords, coord_mask) from mask_matrix_to_coordinate.
    """
    # Idiom fix: was `isinstance(prob_tensor, type(None))` — use `is None`.
    if prob_tensor is None:
        prob_tensor = tf.nn.softmax(logits, axis=-1)[..., dim]
    with tf.name_scope(name or "top_k_to_coordinate"):
        bs, sll = get_shape_list(prob_tensor, expected_rank=2)
        sorted_tensor = tf.contrib.framework.sort(
            prob_tensor, axis=-1, direction='DESCENDING')  # bs,sll
        # Pad with -1 so top_k_vec == sll keeps every (non-negative) position.
        padded_sorted_tensor = tf.concat(
            [sorted_tensor, -tf.ones([bs, 1], sorted_tensor.dtype)],
            axis=-1)  # [bs,sll+1]
        # The k-th largest score per row is the selection threshold.
        k_th_scores_indices = tf.stack(  # [bs,2]
            [
                tf.range(bs, dtype=tf.int32),
                top_k_vec,
            ], axis=-1)
        k_th_scores = tf.expand_dims(
            tf.gather_nd(padded_sorted_tensor, k_th_scores_indices),
            axis=-1)  # [bs,1]
        # Strictly greater: ties at the threshold are dropped (original behavior).
        mask_mat = tf.greater(prob_tensor, k_th_scores)  # [bs,sll]
        return mask_matrix_to_coordinate(mask_mat)
def mask_generation(rep_mask, head_num, use_direction, attn_self, name=None):
    """Self-attention mask: token validity combined with a positional mask.

    :param rep_mask: [bs, sl] bool token mask.
    :param head_num: number of heads; with `use_direction` the first half
        attend forward, the second half backward (head_num should be even).
    :param use_direction: add causal forward/backward positional masks.
    :param attn_self: allow a position to attend to itself.
    :return: [bs, head_num, sl, sl] bool mask.
    :raises ValueError: if use_direction is False and attn_self is True
        (self-attention cannot be forbidden without direction info here).
    """
    # this mask is for self-attention
    with tf.name_scope(name or 'mask_generation'):
        _, sl = get_shape_list(rep_mask, 2)
        # regular mask: both query and key positions must be valid tokens
        rep_mask_epd1 = tf.expand_dims(rep_mask, 1)  # bs,1,sl
        rep_mask_epd2 = tf.expand_dims(rep_mask, 2)  # bs,sl,1
        rep_mask_mat = tf.logical_and(rep_mask_epd1, rep_mask_epd2)  # bs,sl,sl
        # position mask
        sl_indices = tf.range(sl, dtype=tf.int32)
        sl_col, sl_row = tf.meshgrid(sl_indices, sl_indices)
        if use_direction:
            comp_func = tf.greater_equal if attn_self else tf.greater
            fw_mask = comp_func(sl_row, sl_col)  # sl,sl
            bw_mask = comp_func(sl_col, sl_row)  # sl,sl
            direct_mask = tf.stack([fw_mask, bw_mask], 0)  # 2,sl,sl
            # first head_num/2 heads forward, remaining heads backward
            direct_mask = tf.reshape(  # num,sl,sl
                tf.tile(tf.expand_dims(direct_mask, 1),
                        [1, int(head_num / 2), 1, 1]),  # 2,num/2,sl,sl
                [head_num, sl, sl])
        else:
            if not attn_self:
                direct_mask = tf.tile(
                    tf.expand_dims(tf.not_equal(sl_row, sl_col), 0),
                    [head_num, 1, 1])  # n,sl,sl
            else:
                # BUG FIX: was `raise(ValueError, "...")`, which in Python 3
                # raises TypeError ("exceptions must derive from BaseException")
                # instead of the intended ValueError.
                raise ValueError(
                    "A attention overself must be avoided without fw/bw information")
        final_mask = tf.logical_and(  # bs,num,sl,sl
            tf.expand_dims(rep_mask_mat, 1),
            tf.expand_dims(direct_mask, 0))
        return final_mask
def generate_label_mask(input_pos_ids, input_mask, sll):
    """Word-level validity mask: True for word slots below the max word index.

    :param input_pos_ids: [bs, sl] word index per wordpiece (0-based).
    :param input_mask: [bs, sl] bool wordpiece validity mask.
    :param sll: word-level sequence length.
    :return: [bs, sll] bool mask.
    """
    # 1-based word ids with padding zeroed, so max() is the word count per row
    pos_ids = mask_v3(input_pos_ids + 1, input_mask)
    bs, _ = get_shape_list(pos_ids)
    word_slot_idxs = tf.tile(
        tf.expand_dims(tf.range(sll, dtype=tf.int32), 0), [bs, 1])  # bs,sll
    word_counts = tf.reduce_max(pos_ids, axis=-1, keepdims=True)  # [bs,1]
    return tf.less(word_slot_idxs, word_counts)
def number_to_index(num_vec, name=None):
    """Repeat each index by its count: [3, 2, 0, 2, 1] -> [0, 0, 0, 1, 1, 3, 3, 4].

    :param num_vec: [len] int32 counts.
    :return: 1-D int32 tensor of length sum(num_vec).
    """
    with tf.name_scope(name or "number_to_index"):
        n = get_shape_list(num_vec)[0]
        width = tf.reduce_max(num_vec)  # []
        # row i repeated `width` times; the length mask keeps num_vec[i] of them
        row_ids = tf.tile(
            tf.expand_dims(tf.range(n, dtype=tf.int32), -1),
            [1, width])  # [len,num]
        keep = generate_mask_based_on_lens(num_vec, width)  # [len,num]
        return tf.gather_nd(row_ids, tf.where(keep))  # [new]
def extend_batch_for_2nd_dim_compression(input_tensor, reverse_spec=None, num_vec=None, name=None):
    """Repeat batch rows to align with a 2nd-dim-compressed tensor: [bs,...] -> [nbs,...].

    Row i of `input_tensor` is repeated once per valid slot of row i in the
    compression mask, so it can be paired element-wise with the output of
    compress_2nd_dim_to_batch.

    :param reverse_spec: dict from compress_2nd_dim_to_batch; takes priority
        over `num_vec` when both are given.
    :param num_vec: [bs] int counts, used to rebuild the mask when no
        reverse_spec is supplied.
    :return: [nbs, ...] tensor (nbs = total number of valid slots).
    """
    with tf.name_scope(name or "extend_batch_for_2nd_dim_compression"):
        if reverse_spec is not None:
            org_input_mask = reverse_spec["org_input_mask"]
            org_coords = reverse_spec["org_coords"]
            vec_len = get_shape_list(org_input_mask)[0]
        else:
            # rebuild mask/coords from the raw counts
            max_num = tf.reduce_max(num_vec)  # []
            org_input_mask = generate_mask_based_on_lens(num_vec, max_num)
            org_coords = tf.where(org_input_mask)
            vec_len = get_shape_list(num_vec)[0]
        max_num = get_shape_list(org_input_mask)[-1]
        # row index of every valid slot, in compression order
        idx_mat = tf.tile(
            tf.expand_dims(tf.range(vec_len, dtype=tf.int32), -1),
            [1, max_num])
        num_indices = tf.expand_dims(tf.gather_nd(idx_mat, org_coords), -1)  # [nbs,1]
        return tf.gather_nd(input_tensor, num_indices)  # [nbs,...]
def combine_head(inp_tensor, name=None):
    """Merge attention heads: [bs, hd_num, sl, hd_dim] -> [bs, sl, hd_num*hd_dim].

    Works for any rank >= 3 with the head axis at position 1 and the feature
    axis last; both must be statically known ints.
    """
    with tf.name_scope(name or 'combine_head'):
        shape = get_shape_list(inp_tensor)
        # hn = head_num * head_dim requires both to be static
        assert isinstance(shape[1], int) and isinstance(shape[-1], int)
        hn = shape[1] * shape[-1]
        # move the head axis next to the feature axis: [0,1,2,3] -> [0,2,1,3]
        perm = list(range(len(shape)))
        head_axis = perm.pop(1)
        perm.insert(-1, head_axis)
        rearranged = tf.transpose(inp_tensor, perm)  # [bs,sl,hd_num,hd_dim]
        merged_shape = get_shape_list(rearranged)[:-2] + [hn]  # [bs,sl,hn]
        return tf.reshape(rearranged, merged_shape)
def compress_2nd_dim_to_batch(input_tensor, num_vec, name=None):
    """Fold valid 2nd-dim slots into the batch axis: [bs, sd, ...] -> [nbs, 1, ...].

    :param input_tensor: [bs, sd, ...].
    :param num_vec: [bs] int counts of valid slots per row.
    :return: (compressed tensor with a singleton 2nd dim, reverse_spec dict
        usable by decompress_2nd_dim_from_batch).
    """
    with tf.name_scope(name or "compress_2nd_dim_to_batch"):
        sd = get_shape_list(input_tensor)[1]
        valid_mask = generate_mask_based_on_lens(num_vec, sd)  # [bs,sd]
        valid_coords = tf.where(valid_mask)  # [nbs,2]
        reverse_spec = {
            "org_coords": valid_coords,
            "org_input_mask": valid_mask,
        }
        flattened = tf.gather_nd(input_tensor, valid_coords)  # [nbs,...]
        return tf.expand_dims(flattened, 1), reverse_spec
def attn_post_proc(attn_res, inter_hn=None, wd=0., keep_prob=1., residual_keep_prob=1., is_train=None, activation='relu', sparse_opt=False, scope=None, **kwargs):
    """Transformer-style post-attention block: dense + residual, then FFN + residual.

    :param attn_res: [bs, sl, hn] attention output.
    :param inter_hn: FFN inner size; defaults to 4*hn.
    :param sparse_opt: when True, compress padded positions away (using
        kwargs["mask"]) before the dense layers and scatter back at the end.
    :param kwargs: must contain "mask" ([bs, sl] bool).
    :return: [bs, sl, hn] processed tensor.
    """
    with tf.variable_scope(scope or "attn_res"):
        assert "mask" in kwargs
        if sparse_opt:
            # drop padded positions to save compute in the dense layers
            x1, reverse_spec = masked_dense2sparse(attn_res, kwargs.get("mask"))
        else:
            x1 = attn_res
        y = bn_dense_layer_v2(
            x1, get_shape_list(attn_res)[-1], True, 0., "dense_layer",
            "linear", False, wd, keep_prob, is_train
        )
        x2 = residual_connection(x1, y, is_train, residual_keep_prob, "res_con")
        res = residual_connection_with_dense(
            x2, inter_hn or 4 * get_shape_list(attn_res)[-1], True, 0.,
            "residual_connection_with_dense", activation, False, wd,
            keep_prob, is_train, residual_keep_prob
        )
        if sparse_opt:
            # restore the dense [bs, sl, hn] layout
            res = masked_sparse2dense(res, reverse_spec)
        return res
def compress_seq_wrt_mask(tensor_input, tensor_mask):
    """Left-pack valid positions of each sequence: [bs, sl, hn] -> [bs, max_len, hn].

    Valid (masked-True) elements of every row are gathered contiguously to the
    front; max_len is the largest valid count in the batch. Returns enough
    bookkeeping to invert the operation with decompress_seq_wrt_mask.

    :param tensor_input: [bs, sl, hn].
    :param tensor_mask: [bs, sl] bool.
    :return: (compressed tensor, new [bs, max_len] mask, reverse_dict with
        "src_mask"/"tgt_mask"/"coord").
    """
    bs, sl, hn = get_shape_list(tensor_input)
    seq_lens = tf.reduce_sum(tf.cast(tensor_mask, tf.int32), -1)  # [bs]
    max_len = tf.reduce_max(seq_lens)  # []
    new_mask = generate_mask_based_on_lens(seq_lens, max_len)
    # ======> pad every row's mask up to exactly max_len True values, so that
    # tf.where below yields the SAME number of coords per row (bs*max_len total)
    pad_lens = max_len - seq_lens
    max_pad_len = tf.reduce_max(pad_lens)
    pad_mask = generate_mask_based_on_lens(pad_lens, max_pad_len)
    padded_tensor_mask = tf.concat([tensor_mask, pad_mask], axis=-1)  # bs,sl+max_pad_len
    # coordinate map: real positions map to (row, col); pad columns map to
    # (row, -1) so the gather picks up a sentinel column
    bs_idxs = generate_seq_idxs(bs, sl + max_pad_len, transpose=True)  # bs,sl+max_pad_len
    sl_idxs = tf.concat(  # bs,sl+max_pad_len
        [
            generate_seq_idxs(bs, sl, transpose=False),  # bs,sl
            -tf.ones([bs, max_pad_len], tf.int32)  # bs, max_pad_len
        ], axis=-1)
    data_coord_map = tf.stack([bs_idxs, sl_idxs], axis=-1)  # bs,sl+max_pad_len,2
    padded_coord = tf.where(padded_tensor_mask)  # bs*max_len,2
    mapped_padded_coord_rsp = tf.gather_nd(data_coord_map, padded_coord)  # bs*max_len,2
    mapped_padded_coord = tf.reshape(mapped_padded_coord_rsp, [bs, max_len, 2])  # bs,max_len,2
    gathered_data = tf.gather_nd(tensor_input, mapped_padded_coord)  # bs,max_len,hn
    # zero out whatever the sentinel coords dragged in
    masked_gathered_data = mask_v3(gathered_data, new_mask, high_dim=True)
    reverse_dict = {
        "src_mask": tensor_mask,
        "tgt_mask": new_mask,
        "coord": mapped_padded_coord,  # bs,max_len,2
    }
    return masked_gathered_data, new_mask, reverse_dict
def split_head(inp_tensor, head_num, name=None):
    """Split the feature axis into heads: [bs, sl, hn] -> [bs, head_num, sl, hn/head_num].

    hn must be statically divisible by head_num.
    """
    with tf.name_scope(name or 'split_head'):
        shape = get_shape_list(inp_tensor)
        hn = shape[-1]
        assert hn % head_num == 0
        per_head = hn // head_num
        split_shape = shape[:-1] + [head_num, per_head]  # [bs,sl,hd_num,hd_dim]
        # move the new head axis to position 1: [0,1,2,3] -> [0,2,1,3]
        perm = list(range(len(split_shape)))
        perm.insert(1, perm.pop(-2))
        with_heads = tf.reshape(inp_tensor, split_shape)
        return tf.transpose(with_heads, perm)  # [bs,hd_num,sl,hd_dim]
def transform_pos_ids_to_wordpiece_idx(input_pos_ids, input_mask, sll):
    """Convert per-piece word ids into per-piece within-word offsets.

    e.g. pos ids 0 0 1 1 1 2 2 2 2 3 3 (padding...) become
    0 1 0 1 2 0 1 2 3 0 1 (padding zeroed).

    :param input_pos_ids: [bs, sl] word index of each wordpiece.
    :param input_mask: [bs, sl] bool validity mask.
    :param sll: word-level sequence length.
    :return: [bs, sl] int32 within-word piece offsets.
    """
    # 0 0 1 1 1 2 2 2 2 3 3 0 0 0 0 0 # bs,sl #
    bs, sl = get_shape_list(input_pos_ids)
    # 1 at each position where a new word starts (first piece of the word)
    diff_pos = mask_v3(  # bs,sl
        input_pos_ids - tf.concat(
            [tf.zeros([bs, 1], dtype=tf.int32), input_pos_ids[:, :-1]], axis=1),
        input_mask)
    sl_idxs = tf.tile(
        tf.expand_dims(tf.range(sl, dtype=tf.int32), 0), [bs, 1])  # bs,sl
    # wordpiece index at which each word starts (0 elsewhere)
    word_start_index = diff_pos * sl_idxs  # bs, sl
    # remove all 0 value: pad each row so tf.where returns a uniform number
    # of coords per row (slx per row)
    slx_s = tf.reduce_sum(diff_pos, axis=-1)  # the number of non-zero for each example
    slx = tf.reduce_max(slx_s)  #
    sly_s = slx - slx_s  # the number of non-zero for padding
    sly = tf.reduce_max(sly_s)  #
    padding_seq = tf.cast(generate_mask_based_on_lens(sly_s, sly), tf.int32)
    valid_data_mask = generate_mask_based_on_lens(slx_s, slx)  # bs, slx
    padded_word_start_index = tf.concat(
        [word_start_index, padding_seq], axis=-1)  # bs,sl+sly
    data_coord = tf.reshape(  # bs, slx
        tf.where(tf.cast(padded_word_start_index, tf.bool)),  # bs*slx,2
        [bs, slx, 2])
    # per-word start positions laid out word-level, padded to sll
    word_start = tf.concat(  # bs, sll
        [
            tf.zeros([bs, 1], dtype=tf.int32),  # word 0 always starts at piece 0
            mask_v3(tf.gather_nd(padded_word_start_index, data_coord),
                    valid_data_mask),  # bs,slx
            tf.zeros([bs, sll - slx - 1], dtype=tf.int32)
        ], axis=1)
    # look each piece's word-start position back up via its word id
    bs_idxs = generate_seq_idxs(bs, sl, transpose=True)  # bs,sl
    base_coord = tf.stack([bs_idxs, input_pos_ids], axis=-1)  # bs,sl,2
    base_value = tf.gather_nd(word_start, base_coord)  # bs,sl
    # finally: offset of the piece within its word, padding zeroed
    outputs = mask_v3(sl_idxs - base_value, input_mask)  # bs,sl
    return outputs
def logits_for_sketch_index(
        decoder_states, encoder_states, hn=None, wd=0., keep_prob=1.0,
        is_train=None, compress_mask=None, scope=None,
):
    """Bilinear pointer logits from decoder steps onto encoder positions.

    Both sides are linearly projected to `hn`; the decoder side gets a second
    (bias-free) bilinear map before the dot product.

    :param decoder_states: [bs, dsl, d].
    :param encoder_states: [bs, esl, d].
    :param compress_mask: optional [bs, dsl] bool mask; masked-out steps are
        compressed away before the projections and the logits scattered back.
    :return: [bs, dsl, esl] pointer logits.
    """
    compressing = not isinstance(compress_mask, type(None))  # compress_mask is not None
    hn = hn or get_shape_list(decoder_states)[-1]
    with tf.variable_scope(scope or "logits_for_sketch_index"):
        if compressing:
            new_decoder_states, _, rev_d = compress_seq_wrt_mask(
                decoder_states, compress_mask)
        else:
            new_decoder_states = decoder_states
            rev_d = None
        with tf.variable_scope("projection"):
            encoder_states_map = bn_dense_layer_v2(
                encoder_states, hn, True, 0., "encoder_states_map", "linear",
                False, wd, keep_prob, is_train)
            decoder_states_map = bn_dense_layer_v2(
                new_decoder_states, hn, True, 0., "decoder_states_map", "linear",
                False, wd, keep_prob, is_train)
        with tf.variable_scope("bi_linear"):
            # bias-free map implements the bilinear form dec^T W enc
            bilinear_pre = bn_dense_layer_v2(
                decoder_states_map, hn, False, 0., "bilinear_map", "linear",
                False, wd, keep_prob, is_train)
            logits = tf.matmul(bilinear_pre, encoder_states_map,
                               transpose_b=True)  # bs,dsl,esl
        if compressing:
            logits = decompress_seq_wrt_mask(logits, rev_d)
        return logits
def residual_connection_with_dense(x, hn, bias, bias_start=0.0, scope=None, activation='relu', enable_bn=False, wd=0., keep_prob=1.0, is_train=None, residual_keep_prob=1.):
    """Position-wise two-layer FFN (expand to hn, project back) with a residual add.

    :param x: input tensor; output has the same last-dim size.
    :param hn: inner (expanded) hidden size.
    :return: residual_connection(x, FFN(x)).
    """
    with tf.variable_scope(scope or 'residual_connection_with_dense'):
        expanded = bn_dense_layer_v2(
            x, hn, bias, bias_start, "dense_layer_1",
            activation, enable_bn, wd, keep_prob, is_train)
        projected = bn_dense_layer_v2(
            expanded, get_shape_list(x)[-1], bias, bias_start, "dense_layer_2",
            "linear", enable_bn, wd, keep_prob, is_train)
        return residual_connection(
            x, projected, is_train, residual_keep_prob, 'residual_connection')
def _setup_training(self):
    """Wire up the full training graph: logits, losses, predictions, optimizer,
    and the run-time softmax distributions used by decoder beam search."""
    self.logits_dict = self._build_network()
    self.loss, self.loss_dict = self._build_loss()
    self.prediction_dict = self._build_prediction()
    self.log_num_params()
    # to build train op
    self.train_op = optimization.create_optimizer(
        self.loss,
        self.cfg['learning_rate'],
        self.num_training_steps,
        int(self.num_training_steps * self.cfg['warmup_proportion']),
        use_tpu=False
    )
    self.run_dict = {
        "loss": self.loss,
        "loss_seq2seq": self.loss_dict["seq2seq"],
        "loss_seq_label": self.loss_dict["seq_label"],
        "train_op": self.train_op,
    }
    # for decoder beam search #
    # 1. for distribution: the seq2seq logits omit the pad label, so prepend
    # a zero-probability column to realign with the full label vocabulary
    seq2seq_dist_wo_pad = tf.nn.softmax(self.decoder_dict["logits_seq2seq_run"])  # bs,1,nl-1
    self.decoder_dict["distribution_seq2seq_run"] = tf.concat(  # bs,1,nl
        [
            tf.zeros(get_shape_list(seq2seq_dist_wo_pad)[:2] + [1]),  # bs,1,1
            seq2seq_dist_wo_pad,
        ], -1)
    # softmax over each sketch-leaf head's run-time logits
    self.decoder_dict["distribution_sketch_entity_run"] = tf.nn.softmax(
        self.decoder_dict["logits_sketch_entity_run"])
    self.decoder_dict["distribution_sketch_predicate_run"] = tf.nn.softmax(
        self.decoder_dict["logits_sketch_predicate_run"])
    self.decoder_dict["distribution_sketch_type_run"] = tf.nn.softmax(
        self.decoder_dict["logits_sketch_type_run"])
    self.decoder_dict["distribution_sketch_num_run"] = tf.nn.softmax(
        self.decoder_dict["logits_sketch_num_run"])
def mask_matrix_to_coordinate(mask_mat, name=None):
    """Convert a [bs, sll] bool matrix into per-row, left-packed coordinates.

    Each row's True positions are listed first; rows with fewer True values
    are padded (and the pad entries zeroed via the returned mask).

    :param mask_mat: [bs, sll] bool.
    :return: ([bs, max_real_len, 2] int coords, [bs, max_real_len] bool mask).
    """
    with tf.name_scope(name or "mask_matrix_to_coordinate"):
        bs, sll = get_shape_list(mask_mat, expected_rank=2)
        # lens
        real_lens = tf.reduce_sum(tf.cast(mask_mat, tf.int32), axis=-1)  # bs
        max_real_len = tf.reduce_max(real_lens, axis=0)  # []
        pad_lens = max_real_len - real_lens
        max_pad_len = tf.reduce_max(pad_lens, axis=0)
        # mask generation: pad each row to exactly max_real_len True values so
        # tf.where yields the same number of coords per row
        pad_mask_mat = generate_mask_based_on_lens(pad_lens, max_pad_len)
        coord_mask = generate_mask_based_on_lens(real_lens, max_real_len)
        # coord generation
        padded_mask_mat = tf.concat([mask_mat, pad_mask_mat], axis=-1)
        flat_coords = tf.where(padded_mask_mat)  # [bs*max_real_len,2]
        coords = tf.reshape(flat_coords, [bs, max_real_len, 2])  # [bs,max_real_len]
        # zero out the coords that came from padding columns
        coords = mask_v3(coords, coord_mask, high_dim=True)
        return coords, coord_mask
def direct_mask_generation(rep_mask, direct, attn_self, name=None):
    """Directional (causal) self-attention mask combined with token validity.

    :param rep_mask: [bs, sl] bool token mask.
    :param direct: "forward" (attend to earlier positions) or "backward".
    :param attn_self: whether a position may attend to itself.
    :return: [bs, sl, sl] bool mask.
    """
    assert direct in ["forward", "backward"]
    with tf.name_scope(name or 'direct_mask_generation'):
        bs, sl = get_shape_list(rep_mask, 2)
        # pairwise validity: query AND key positions must be real tokens
        pairwise_valid = tf.logical_and(
            tf.expand_dims(rep_mask, 1),  # bs,1,sl
            tf.expand_dims(rep_mask, 2))  # bs,sl,1
        # positional (triangular) mask
        idxs = tf.range(sl, dtype=tf.int32)
        col, row = tf.meshgrid(idxs, idxs)
        cmp = tf.greater_equal if attn_self else tf.greater
        if direct == "forward":
            position_mask = cmp(row, col)  # sl,sl
        elif direct == "backward":
            position_mask = cmp(col, row)
        else:
            raise AttributeError  # unreachable given the assert above
        position_mask = tf.tile(tf.expand_dims(position_mask, 0), [bs, 1, 1])
        return tf.logical_and(pairwise_valid, position_mask)
def __init__(self, cfg, tokenizer, data_type, labels_dict, max_sequence_len, num_training_steps, scope):
    """Build placeholders, label bookkeeping, and the training graph.

    :param cfg: dict of model hyper-parameters (may override base-class
        architecture values via the *_input keys below).
    :param tokenizer: tokenizer exposing a `vocab` mapping.
    :param labels_dict: label inventories for "EOs", "types", "predicates",
        and "sketch".
    """
    # Optional architecture overrides from cfg; None keeps base-class defaults.
    if "level_for_dec" in cfg and cfg['level_for_dec'] >= 0:
        num_hidden_layers = cfg['level_for_dec'] + 1
    else:
        num_hidden_layers = None
    if "hidden_size_input" in cfg and cfg['hidden_size_input'] > 0:
        hidden_size = cfg['hidden_size_input']
    else:
        hidden_size = None
    if "num_attention_heads_input" in cfg and cfg['num_attention_heads_input'] > 0:
        num_attention_heads = cfg['num_attention_heads_input']
    else:
        num_attention_heads = None
    if "intermediate_size_input" in cfg and cfg['intermediate_size_input'] > 0:
        intermediate_size = cfg['intermediate_size_input']
    else:
        intermediate_size = None
    if "hidden_dropout_prob_input" in cfg and cfg['hidden_dropout_prob_input'] > 0:
        hidden_dropout_prob = cfg['hidden_dropout_prob_input']
    else:
        hidden_dropout_prob = None
    if "attention_probs_dropout_prob_input" in cfg and cfg['attention_probs_dropout_prob_input'] > 0:
        attention_probs_dropout_prob = cfg['attention_probs_dropout_prob_input']
    else:
        attention_probs_dropout_prob = None
    super(ModelBertTemplate, self).__init__(
        cfg, is_paired_data=False, scope=scope,
        num_hidden_layers=num_hidden_layers,
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size,
        hidden_dropout_prob=hidden_dropout_prob,
        attention_probs_dropout_prob=attention_probs_dropout_prob,
    )
    self.data_type = data_type
    self.labels_dict = labels_dict
    self.max_sequence_len = max_sequence_len
    self.num_training_steps = num_training_steps
    self.vocab = tokenizer.vocab
    self.tokenizer = tokenizer
    self.input_pos_ids = tf.placeholder(tf.int32, [None, None])  # [bs, sl] word id per piece
    self.loss_gain_wrt_qt = tf.placeholder(tf.float32, [None])  # [bs] per-example loss gain
    # ==== an introduction to lengths =====
    # [prev_q] [sep] [prev_a] [sep1] [cur_q] [cls]
    # sl: seq len, wordpiece-level, ([prev_q] [sep] [prev_a] [sep1] [cur_q] [cls])
    # sll: seq label len, token-level, ([prev_q] [sep] [prev_a] [sep1] [cur_q])
    # asl: all seq len, token_level, ([prev_q] [sep] [prev_a] [sep1] [cur_q] [cls])
    # others,
    # pl, piece len, the max len of word pieces belonging to a word
    # ====== labels =====
    # 1. EO
    self.num_EO_labels = len(labels_dict["EOs"]["labels"])
    self.num_type_labels = len(labels_dict["types"]["labels"])
    self.EO_label = tf.placeholder(tf.int32, [None, None])  # [bs, sll] with [0,nel)
    self.entity_type_label = tf.placeholder(tf.int32, [None, None])  # [bs, sll] with (1,ntl)
    # 2. Sketches: include sketch itself and leaves labels: entity, predicate, type and num
    self.sos_id = labels_dict["sketch"]["labels"].index(SOS_TOKEN)
    self.eos_id = labels_dict["sketch"]["labels"].index(EOS_TOKEN)
    self.num_predicate_labels = len(labels_dict["predicates"]["labels"])
    self.num_sketch_labels = len(labels_dict["sketch"]["labels"])
    self.sketch_label = tf.placeholder(tf.int32, [None, None])  # bs,dsl+1
    # teacher-forcing split: outputs are labels shifted left, inputs shifted right
    self.sketch_output_ids = self.sketch_label[:, 1:]  # bs,dsl
    self.sketch_mask = tf.cast(self.sketch_output_ids, tf.bool)  # bs,dsl (0 == pad)
    self.sketch_input_ids = self.sketch_label[:, :-1] * tf.cast(self.sketch_mask, tf.int32)  # bs,dsl
    self.sketch_entity = tf.placeholder(tf.int32, [None, None])  # bs,dsl
    self.sketch_predicate = tf.placeholder(tf.int32, [None, None])  # bs,dsl
    self.sketch_type = tf.placeholder(tf.int32, [None, None])  # bs,dsl
    self.sketch_num = tf.placeholder(tf.int32, [None, None])  # bs,dsl
    # # 2.1 masks: -1 marks "no entity/num label", 0 marks "no predicate/type label"
    self.sketch_entity_mask = tf.not_equal(self.sketch_entity, -1)
    self.sketch_predicate_mask = tf.not_equal(self.sketch_predicate, 0)
    self.sketch_type_mask = tf.not_equal(self.sketch_type, 0)
    self.sketch_num_mask = tf.not_equal(self.sketch_num, -1)
    # lens (derived dynamically from the feed)
    self.asl = tf.reduce_max(self.input_pos_ids) + 1  # all seq len, token-level
    self.sll = get_shape_list(self.EO_label)[-1]  # seq labeling len, token-level
    self.wordpiece_idx = transform_pos_ids_to_wordpiece_idx(  # bs,asl
        self.input_pos_ids, self.input_mask, self.asl)
    self.pl = tf.reduce_max(self.wordpiece_idx) + 1  # max pieces per word
    # masks
    self.seq_label_mask = tf.cast(self.EO_label, bool)  # bs,sll (label 0 == pad)
    self.wordpiece_mask = get_word_level_split(  # bs,asl,pl
        self.input_mask, self.input_pos_ids, self.wordpiece_idx,
        self.input_mask, self.asl, self.pl
    )
    # special token indices
    self.unk_id, self.cls_id, self.sep_id, self.empty_id, self.sep1_id, self.pad_id = convert_tokens_to_ids(
        self.vocab, [
            SPECIAL_TOKENS["UNK"], SPECIAL_TOKENS["CLS"], SPECIAL_TOKENS["SEP"],
            SPECIAL_TOKENS["EMPTY"], SPECIAL_TOKENS["SEP1"], SPECIAL_TOKENS["PAD"]])
    # for the decoding: sketch-label input embedding matrix
    self.dec_input_emb_mat = tf.get_variable(
        "dec_input_emb_mat", [self.num_sketch_labels, self.cfg["hn"]],
        initializer=tf.truncated_normal_initializer(0, 0.05)
    )
    # for the key indices
    first_ids = get_word_level_split(  # bs,sl -> bs,asl,pl -> bs,asl
        self.input_ids, self.input_pos_ids, self.wordpiece_idx,
        self.input_mask, self.asl, self.pl
    )[..., 0]  # get the 1st id in each wordpieces
    self.sep_indices = tf.stack(
        get_key_indices(first_ids, [self.sep_id, self.sep1_id, self.cls_id]),
        axis=-1)
    # decode-time plumbing: feed placeholders plus slots filled in later by
    # the run-graph builders (the *_run entries start as None)
    self.decoder_dict = {
        # placeholders for step-wise decoding
        "encoder_states_placeholder": tf.placeholder(tf.float32, [None, None, cfg["hn"]]),  # bs,sl,hn
        "encoder_output_for_predicate_placeholder": tf.placeholder(tf.float32, [None, cfg["hn"]]),
        "encoder_output_for_type_placeholder": tf.placeholder(tf.float32, [None, cfg["hn"]]),
        "encoder_ids_placeholder": tf.placeholder(tf.float32, [None, None]),  # bs,sl
        "decoder_history_placeholder": tf.placeholder(tf.float32, [None, cfg["decoder_layer"], None, cfg["hn"]]),  # bs,t,hn
        "decoder_ids_placeholder": tf.placeholder(tf.int32, [None, 1]),
        "is_training_placeholder": self.is_training,
        # intermediate tensors, populated during graph construction
        "encoder_states_run": None,
        "encoder_output_for_predicate_run": None,
        "encoder_output_for_type_run": None,
        "decoder_history_run": None,
        "logits_seq2seq_run": None,
        "logits_sketch_entity_run": None,
        "logits_sketch_predicate_run": None,
        "logits_sketch_type_run": None,
        "logits_sketch_num_run": None,
    }
    self.decoder_dict["encoder_mask"] = tf.cast(
        self.decoder_dict["encoder_ids_placeholder"], tf.bool)
    # filled by _setup_training()
    self.logits_dict = None
    self.loss_dict = None
    self.prediction_dict = None
    self.loss = None
    self.train_op = None
    self.run_dict = None
    self._setup_training()
def _build_loss(self):
    """Combine the sequence-labeling loss and the sketch seq2seq (+leaves) loss.

    :return: (scalar optimization loss, dict with "seq_label" and "seq2seq").
    """
    # for seq label: fold (EO, entity-type) into one joint class id;
    # 0 is reserved for empty/pad, valid pairs map to 1..(4*ntl) via
    # (EO-2) + 4*(type-2) + 1
    joint_label = tf.where(  # 0 for empty or pad
        tf.logical_and(tf.greater_equal(self.EO_label, 2),
                       tf.greater_equal(self.entity_type_label, 2)),
        (self.EO_label - 2) + 4 * (self.entity_type_label - 2) + 1,
        tf.zeros(get_shape_list(self.EO_label), tf.int32)
    )
    joint_label_rsp = tf.reshape(joint_label, [self.bs * self.sll])
    logits_seq_label_rsp = tf.reshape(
        self.logits_dict["seq_label"],
        [-1, get_shape_list(self.logits_dict["seq_label"])[-1]])
    losses_seq_label_rsp = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=joint_label_rsp, logits=logits_seq_label_rsp
    )
    losses_seq_label = tf.reshape(losses_seq_label_rsp, [self.bs, self.sll])
    seq_label_mask_tf = tf.cast(self.seq_label_mask, tf.float32)
    # up-weight positive (non-zero) labels by cfg["pos_gain"], zero out padding
    seq_label_weights = tf.where(
        tf.greater(joint_label, 0),
        tf.ones_like(losses_seq_label) * self.cfg["pos_gain"],
        tf.ones_like(losses_seq_label)
    ) * seq_label_mask_tf
    loss_seq_label = \
        tf.reduce_sum(losses_seq_label * seq_label_weights) / tf.reduce_sum(seq_label_weights)
    # for sequence to sequence #
    # 1. sketch loss: the logits omit the pad label, so valid ids shift down by 1
    label_seq2seq = tf.where(
        self.sketch_mask,
        self.sketch_output_ids - 1,  # for valid token - 1
        self.sketch_output_ids
    )
    losses_sketch = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=label_seq2seq, logits=self.logits_dict["seq2seq"]
    )
    # # 2. leaves losses: each head's loss is zeroed where its label is absent
    # # # 2.1 entity (-1 means absent; multiplied to 0 before the CE to keep labels in range)
    losses_sketch_entity = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=self.sketch_entity * tf.cast(self.sketch_entity_mask, tf.int32),
        logits=self.logits_dict["sketch_entity"]
    ) * tf.cast(self.sketch_entity_mask, tf.float32)
    # # # 2.2 predicate
    losses_sketch_predicate = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=self.sketch_predicate,
        logits=self.logits_dict["sketch_predicate"]
    ) * tf.cast(self.sketch_predicate_mask, tf.float32)
    # # # 2.3 type
    losses_sketch_type = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=self.sketch_type,
        logits=self.logits_dict["sketch_type"]
    ) * tf.cast(self.sketch_type_mask, tf.float32)
    # # # 2.4 num
    losses_sketch_num = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=self.sketch_num * tf.cast(self.sketch_num_mask, tf.int32),
        logits=self.logits_dict["sketch_num"]
    ) * tf.cast(self.sketch_num_mask, tf.float32)
    # # 3 combine leaves' losses (only on real sketch tokens)
    losses_sketch_leaves = \
        (losses_sketch_entity + losses_sketch_predicate + losses_sketch_type + losses_sketch_num) * \
        tf.cast(self.sketch_mask, tf.float32)
    # # 4. combine to the sketch loss
    losses_seq2seq = losses_sketch + losses_sketch_leaves
    # # 5. calc final loss: per-example mean over steps, then a weighted
    # batch mean over examples that actually have a sketch
    sketch_mask_ft = tf.cast(self.sketch_mask, tf.float32)  # bs,sl
    sketch_mask_int = tf.cast(self.sketch_mask, tf.int32)  # bs,sl
    sketch_ex_mask = tf.cast(tf.reduce_sum(sketch_mask_int, -1), tf.bool)  # bs
    sketch_ex_mask_ft = tf.cast(sketch_ex_mask, tf.float32)  # bs
    seq_deno = tf.reduce_sum(sketch_mask_ft, -1)
    # epsilon floor guards division by zero for sketch-less examples
    seq_deno = tf.where(
        tf.greater(seq_deno, 0.),
        seq_deno,
        tf.ones_like(seq_deno) * 1e-6,
    )
    loss_seq2seq_example = tf.reduce_sum(sketch_mask_ft * losses_seq2seq, -1) / seq_deno  # bs
    loss_seq2seq_example = loss_seq2seq_example * self.loss_gain_wrt_qt
    batch_deno = tf.reduce_sum(sketch_ex_mask_ft * self.loss_gain_wrt_qt)
    batch_deno = tf.where(
        tf.greater(batch_deno, 0.),
        batch_deno,
        tf.ones_like(batch_deno) * 1e-6,
    )
    loss_seq2seq = tf.reduce_sum(sketch_ex_mask_ft * loss_seq2seq_example) / batch_deno
    opt_loss = self.cfg["seq_label_loss_weight"]*loss_seq_label + \
        self.cfg["seq2seq_loss_weight"] * loss_seq2seq
    return opt_loss, {
        "seq_label": loss_seq_label,
        "seq2seq": loss_seq2seq,
    }
def bn_dense_layer_multi_head(input_tensor,
                              hn,
                              bias,
                              bias_start=0.0,
                              scope=None,
                              activation='relu',
                              enable_bn=False,
                              wd=0.,
                              keep_prob=1.0,
                              is_train=None,
                              dup_num=1,
                              merge_var=False):
    """Per-head dense layer for a >=3-d tensor laid out as [bs, hd, ..., dim].

    Each head (2nd dimension) gets its own weight matrix; the transform is
    applied via a single batched matmul after transposing the head dimension
    to the front and flattening the middle dimensions.

    Args:
        input_tensor: tensor of rank >= 3; dim 0 is batch, dim 1 is head,
            last dim is the feature dim (comments below use [bs,hd,sl,dim]).
        hn: output feature size per duplicate.
        bias: whether to add a learned per-head bias.
        bias_start: constant initializer value for the bias.
        scope: variable scope name (default 'bn_dense_layer_multi_head').
        activation: activation name resolved through act_name2fn.
        enable_bn: unsupported here — asserted False.
        wd: if truthy, the weight is added to the 'reg_vars' collection.
        keep_prob / is_train: dropout configuration applied to the input.
        dup_num: number of duplicated projections; outputs are concatenated
            on the last axis, giving last dim hn * dup_num.
        merge_var: create one merged variable instead of per-head/per-dup
            variables (different checkpoint layout, same math).

    Returns:
        act_fn(output) with shape input_shape[:-1] + [hn * dup_num].
    """
    assert not enable_bn
    act_fn = act_name2fn(activation)
    with tf.variable_scope(scope or 'bn_dense_layer_multi_head'):
        input_tensor = dropout(input_tensor, keep_prob,
                               is_train)  # dropout [bs,hd,sl,dim]

        input_shape = get_shape_list(input_tensor)  # [4] for [bs,hd,sl,dim]
        assert len(input_shape) >= 3

        # exchange 1st and 2nd dimension so the head dim leads the batch
        # of matmuls below
        perm_t = list(range(len(input_shape)))  # [0,1,2,3]
        perm_t[0], perm_t[1] = perm_t[1], perm_t[0]  # [1,0,2,3]
        input_tensor_t = tf.transpose(input_tensor, perm_t)  # [hd,bs,sl,dim]

        # merge all middle dims and reshape to rank 3 for batched matmul
        input_shape_t = get_shape_list(
            input_tensor_t)  # [4] for [hd,bs,sl,dim]
        dims_merge = input_shape_t[1:-1]  # [2] for [bs,sl]
        new_dim = reduce(mul, dims_merge)  # bs*sl
        new_shape = [input_shape_t[0], new_dim,
                     input_shape_t[-1]]  # [3] for [hd,bs*sl,dim]
        input_tensor_rsp = tf.reshape(input_tensor_t,
                                      new_shape)  # [hd,bs*sl,dim]

        # dense layer: one [hd_dim, hn*dup_num] matrix per head
        hd_num = new_shape[0]  # head num
        hd_dim = new_shape[-1]  # head dim
        if merge_var:
            weight = tf.get_variable('W', shape=[hd_num, hd_dim, hn * dup_num])
        else:
            weight_list = []
            for i in range(hd_num):
                sub_weight_list = []
                for j in range(dup_num):
                    sub_weight_list.append(
                        tf.get_variable('W_%d_%d' % (i, j),
                                        shape=[hd_dim, hn]))
                weight_list.append(
                    tf.concat(sub_weight_list, -1
                              ) if dup_num > 1 else sub_weight_list[0])
            weight = tf.stack(weight_list, 0)

        out_rsp = tf.matmul(input_tensor_rsp, weight)  # hd_num, bs*sl, hn*dup

        if bias:
            if merge_var:
                # BUGFIX: out_rsp's last dim is hn * dup_num, so the merged
                # bias must match it — previously [hd_num, 1, hn], which
                # fails to broadcast whenever dup_num > 1 (the non-merged
                # branch below and bn_dense_layer_v2 both use hn * dup_num).
                bias_val = tf.get_variable(
                    'bias',
                    shape=[hd_num, 1, hn * dup_num],
                    dtype=tf.float32,
                    initializer=tf.constant_initializer(bias_start))
            else:
                bias_list = []
                for i in range(hd_num):
                    sub_bias_list = []
                    for j in range(dup_num):
                        sub_bias_list.append(
                            tf.get_variable(
                                'bias_%d_%d' % (i, j),
                                shape=[1, hn],
                                dtype=tf.float32,
                                initializer=tf.constant_initializer(
                                    bias_start)))
                    bias_list.append(
                        tf.concat(sub_bias_list, -1
                                  ) if dup_num > 1 else sub_bias_list[0])
                bias_val = tf.stack(bias_list, 0)
            out_rsp = out_rsp + bias_val  # hd_num, bs*sl, hn*dup_num

        # un-merge the middle dims
        output_shape_t = [new_shape[0]
                          ] + dims_merge + [hn]  # [4] for [hd,bs,sl,new_dim]
        output_t = tf.reshape(out_rsp, output_shape_t)  # [hd,bs,sl,new_dim]

        # transpose head back behind batch (perm_t is its own inverse)
        output = tf.transpose(output_t, perm_t)  # [bs,hd,sl,new_dim]

        if wd:
            tf.add_to_collection('reg_vars', weight)
        return act_fn(output)
def bn_dense_layer_v2(input_tensor,
                      hn,
                      bias,
                      bias_start=0.0,
                      scope=None,
                      activation='relu',
                      enable_bn=False,
                      wd=0.,
                      keep_prob=1.0,
                      is_train=None,
                      dup_num=1,
                      merge_var=False):
    """Dense layer over the last dim of a tensor of rank >= 2.

    All leading dims are flattened for one matmul and restored afterwards.
    With dup_num > 1, that many parallel projections are concatenated on
    the last axis, so the output's last dim is hn * dup_num. merge_var
    switches between one merged variable and per-duplicate variables
    (different checkpoint layout, identical math). If wd is truthy the
    weight is recorded in the 'reg_vars' collection; enable_bn appends a
    batch-norm before the activation.
    """
    act_fn = act_name2fn(activation)
    with tf.variable_scope(scope or 'bn_dense_layer'):
        x = dropout(input_tensor, keep_prob, is_train)

        # flatten every dim except the last into a single row axis
        # (comments use a 3-d [bs,sl,dim] example)
        x_shape = get_shape_list(x)
        assert len(x_shape) >= 2  # at least [bs,dim]
        lead_dims = x_shape[:-1]
        in_dim = x_shape[-1]
        flat_rows = reduce(mul, lead_dims)
        x_flat = tf.reshape(x, [flat_rows, in_dim])  # [bs*sl, dim]

        # projection weight: merged single variable, or dup_num concatenated
        if merge_var:
            weight = tf.get_variable('W',
                                     shape=[in_dim, hn * dup_num],
                                     dtype=tf.float32)
        else:
            weight = tf.concat([
                tf.get_variable('W_%d' % i, shape=[in_dim, hn])
                for i in range(dup_num)
            ], -1)
        y_flat = tf.matmul(x_flat, weight)

        if bias:
            if merge_var or dup_num == 1:
                b = tf.get_variable(
                    'bias',
                    shape=[hn * dup_num],
                    dtype=tf.float32,
                    initializer=tf.constant_initializer(bias_start))
            else:
                b = tf.concat([
                    tf.get_variable(
                        'bias_%d' % i,
                        shape=[hn],
                        dtype=tf.float32,
                        initializer=tf.constant_initializer(bias_start))
                    for i in range(dup_num)
                ], -1)
            y_flat += b

        # restore the leading dims: [bs,sl,hn*dup_num]
        output = tf.reshape(y_flat, lead_dims + [hn * dup_num])

        if enable_bn:
            output = tf.contrib.layers.batch_norm(output,
                                                  center=True,
                                                  scale=True,
                                                  is_training=is_train,
                                                  updates_collections=None,
                                                  decay=0.9,
                                                  scope='bn')
        if wd:
            tf.add_to_collection('reg_vars', weight)
        return act_fn(output)