def _forward(self, is_training, split_placeholders, **kwargs):
    # At inference time the augmented copies are not needed, so fall back
    # to the plain classifier forward pass.
    if not is_training:
        return super()._forward(is_training, split_placeholders, **kwargs)

    # Select the augmented copies of the unsupervised examples and append
    # them to the batch.
    aug_input_ids = tf.boolean_mask(
        split_placeholders['aug_input_ids'],
        mask=(1.0 - split_placeholders['is_supervised']),
        axis=0)
    aug_input_mask = tf.boolean_mask(
        split_placeholders['aug_input_mask'],
        mask=(1.0 - split_placeholders['is_supervised']),
        axis=0)
    aug_segment_ids = tf.boolean_mask(
        split_placeholders['aug_segment_ids'],
        mask=(1.0 - split_placeholders['is_supervised']),
        axis=0)
    input_ids = tf.concat(
        [split_placeholders['input_ids'], aug_input_ids], axis=0)
    input_mask = tf.concat(
        [split_placeholders['input_mask'], aug_input_mask], axis=0)
    segment_ids = tf.concat(
        [split_placeholders['segment_ids'], aug_segment_ids], axis=0)

    encoder = BERTEncoder(
        bert_config=self.bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        scope='bert',
        drop_pooler=self._drop_pooler,
        **kwargs)
    encoder_output = encoder.get_pooled_output()

    # Mark which rows of the concatenated batch are augmented copies.
    label_ids = split_placeholders['label_ids']
    is_expanded = tf.zeros_like(label_ids, dtype=tf.float32)
    batch_size = util.get_shape_list(aug_input_ids)[0]
    aug_is_expanded = tf.ones((batch_size), dtype=tf.float32)
    is_expanded = tf.concat([is_expanded, aug_is_expanded], axis=0)

    decoder = UDADecoder(
        is_training=is_training,
        input_tensor=encoder_output,
        is_supervised=split_placeholders['is_supervised'],
        is_expanded=is_expanded,
        label_ids=label_ids,
        label_size=self.label_size,
        sample_weight=split_placeholders.get('sample_weight'),
        scope='cls/seq_relationship',
        global_step=self._global_step,
        num_train_steps=self.total_steps,
        uda_softmax_temp=self._uda_softmax_temp,
        uda_confidence_thresh=self._uda_confidence_thresh,
        tsa_schedule=self._tsa_schedule,
        **kwargs)
    (total_loss, losses, probs, preds) = decoder.get_forward_outputs()
    return (total_loss, losses, probs, preds)
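
# A minimal standalone sketch (not part of the module above; all names and
# values here are illustrative) of how the `is_supervised` flags split a
# batch before the augmented copies are appended. The module passes the
# float flags straight to `tf.boolean_mask`, relying on nonzero entries
# being treated as True; the explicit cast below makes that visible.
import tensorflow as tf

input_ids = tf.constant([[11, 12], [21, 22], [31, 32]])   # 2 supervised rows + 1 unsupervised row
is_supervised = tf.constant([1.0, 1.0, 0.0])

# rows whose augmented copies get appended to the training batch
unsup_ids = tf.boolean_mask(
    input_ids, mask=tf.cast(1.0 - is_supervised, tf.bool), axis=0)
batch = tf.concat([input_ids, unsup_ids], axis=0)

with tf.Session() as sess:
    print(sess.run(batch))   # rows 0-2 are the originals, row 3 is the appended copy
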
def _forward(dilated_ids, dilated_mask):

    logits = self._bert_forward(
        bert_config,
        dilated_ids,
        dilated_mask,
        batch_size,
        dilated_seq_length,
        tilda_embeddings=tilda_embeddings)
    output_ids = tf.argmax(logits, axis=-1)
    output_ids = tf.cast(output_ids, dtype=tf.int32)

    # special padding (using `spad` token)
    equal_zero = tf.cast(tf.equal(output_ids, 0), tf.int32)
    equal_zero = tf.reduce_sum(equal_zero, axis=-1)
    right_pad = spad_id * tf.sequence_mask(
        equal_zero, dilated_seq_length, dtype=tf.int32)
    paded = tf.concat([output_ids, right_pad], axis=-1)

    # extract ids of length `max_seq_length`
    flattened_padded = tf.reshape(paded, [-1])
    is_valid = tf.cast(tf.greater(flattened_padded, 0), dtype=tf.int32)
    flattened_valid = tf.boolean_mask(flattened_padded, is_valid)
    valid = tf.reshape(flattened_valid, [batch_size, dilated_seq_length])
    cutted_valid = valid[:, :max_seq_length]

    # replace `spad` token with `pad`
    non_spad_mask = tf.cast(tf.not_equal(cutted_valid, spad_id),
                            dtype=tf.int32)
    output_ids = cutted_valid * non_spad_mask
    output_length = tf.reduce_sum(non_spad_mask, axis=-1)

    # dilate
    reshaped_ids = tf.reshape(output_ids, [batch_size, max_seq_length, 1])
    reshaped_mask = tf.reshape(
        tf.sequence_mask(output_length, max_seq_length, dtype=tf.int32),
        [batch_size, max_seq_length, 1])
    concat_ids = tf.concat(
        [reshaped_ids, tf.zeros_like(reshaped_ids)], axis=-1)
    concat_mask = tf.concat(
        [reshaped_mask, tf.zeros_like(reshaped_mask, dtype=tf.int32)],
        axis=-1)
    dilated_ids = tf.reshape(concat_ids, [batch_size, max_seq_length * 2])
    dilated_mask = tf.reshape(concat_mask, [batch_size, max_seq_length * 2])

    return dilated_ids, dilated_mask
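
# A minimal standalone sketch (assumption, not part of the module above) of
# the `spad` right-padding trick used in `_forward`: every zero that the
# boolean mask later drops is compensated by one appended `spad` token, so
# each row still flattens back to exactly `dilated_seq_length` entries.
import tensorflow as tf

spad_id, dilated_seq_length = 99, 6
output_ids = tf.constant([[5, 0, 6, 0, 7, 0]])              # 3 zeros to be dropped
num_zeros = tf.reduce_sum(tf.cast(tf.equal(output_ids, 0), tf.int32), axis=-1)
right_pad = spad_id * tf.sequence_mask(
    num_zeros, dilated_seq_length, dtype=tf.int32)
padded = tf.concat([output_ids, right_pad], axis=-1)        # [5 0 6 0 7 0 99 99 99 0 0 0]

flat = tf.reshape(padded, [-1])
kept = tf.boolean_mask(flat, tf.greater(flat, 0))           # zeros removed, spad kept
valid = tf.reshape(kept, [1, dilated_seq_length])

with tf.Session() as sess:
    print(sess.run(valid))   # [[ 5  6  7 99 99 99]]
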
def _forward(dilated_ids, dilated_mask):

    logits = self._bert_forward(
        bert_config,
        dilated_ids,
        dilated_mask,
        batch_size,
        dilated_seq_length,
        tilda_embeddings=tilda_embeddings)
    output_ids = tf.argmax(logits, axis=-1)
    output_ids = tf.cast(output_ids, dtype=tf.int32)

    # special padding (using `spad` token)
    equal_zero = tf.cast(tf.equal(output_ids, 0), tf.int32)
    equal_zero = tf.reduce_sum(equal_zero, axis=-1)
    right_pad = spad_id * tf.sequence_mask(
        equal_zero, dilated_seq_length, dtype=tf.int32)
    paded = tf.concat([output_ids, right_pad], axis=-1)

    # extract ids of length `max_seq_length`
    flattened_padded = tf.reshape(paded, [-1])
    is_valid = tf.cast(tf.greater(flattened_padded, 0), dtype=tf.int32)
    flattened_valid = tf.boolean_mask(flattened_padded, is_valid)
    valid = tf.reshape(flattened_valid, [batch_size, dilated_seq_length])
    cutted_valid = valid[:, :max_seq_length]

    # replace `spad` token with `pad`
    nonpad_mask = tf.cast(tf.not_equal(cutted_valid, spad_id),
                          dtype=tf.int32)
    output_ids = cutted_valid * nonpad_mask

    # dilate the compact ids back to length `max_seq_length * 2`
    reshaped = tf.reshape(output_ids, [batch_size, max_seq_length, 1])
    concatenated = tf.concat([reshaped, tf.zeros_like(reshaped)], axis=-1)
    dilated_ids = tf.reshape(concatenated, [batch_size, max_seq_length * 2])

    # number of non-pad tokens per sequence
    input_mask = tf.reduce_sum(nonpad_mask, axis=-1)
    dilated_mask = tf.sequence_mask(input_mask, dilated_seq_length,
                                    dtype=tf.int32)

    return dilated_ids, dilated_mask
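
# A minimal standalone sketch (assumption, not part of the module above) of
# the `dilate` step in both `_forward` variants: each predicted id is
# interleaved with a zero slot, so a compact sequence of length
# `max_seq_length` becomes a dilated sequence of length `max_seq_length * 2`.
import tensorflow as tf

output_ids = tf.constant([[7, 8, 9, 0]])                    # [batch=1, max_seq_length=4]
reshaped = tf.reshape(output_ids, [1, 4, 1])
dilated = tf.reshape(
    tf.concat([reshaped, tf.zeros_like(reshaped)], axis=-1), [1, 8])

with tf.Session() as sess:
    print(sess.run(dilated))   # [[7 0 8 0 9 0 0 0]]
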
def _expand_features(module, split_placeholders):

    inputs = split_placeholders['input']
    target = split_placeholders['target']
    is_masked = tf.cast(split_placeholders['is_masked'], tf.bool)
    batch_size = tf.shape(inputs)[0]

    non_reuse_len = module.max_seq_length - module.reuse_seq_length
    assert (module.perm_size <= module.reuse_seq_length and
            module.perm_size <= non_reuse_len)

    # permute the reused and non-reused segments separately
    (perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0) = \
        _local_perm(
            inputs[:, :module.reuse_seq_length],
            target[:, :module.reuse_seq_length],
            is_masked[:, :module.reuse_seq_length],
            module.perm_size,
            module.reuse_seq_length)
    (perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1) = \
        _local_perm(
            inputs[:, module.reuse_seq_length:],
            target[:, module.reuse_seq_length:],
            is_masked[:, module.reuse_seq_length:],
            module.perm_size,
            non_reuse_len)

    # merge the two permutation masks into a block structure
    perm_mask_0 = tf.concat([
        tf.cast(perm_mask_0, dtype=tf.float32),
        tf.ones([batch_size, module.reuse_seq_length, non_reuse_len])
    ], axis=2)
    perm_mask_1 = tf.concat([
        tf.zeros([batch_size, non_reuse_len, module.reuse_seq_length]),
        tf.cast(perm_mask_1, dtype=tf.float32)
    ], axis=2)
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=1)
    target = tf.concat([target_0, target_1], axis=1)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=1)
    input_k = tf.concat([input_k_0, input_k_1], axis=1)
    input_q = tf.concat([input_q_0, input_q_1], axis=1)

    if module._num_predict is not None:
        # TODO(geying): convert tensors from 1-D to 2-D
        indices = tf.range(module.max_seq_length, dtype=tf.int64)
        indices = tf.reshape(indices, [-1, module.max_seq_length])
        indices = tf.tile(indices, [batch_size, 1])
        bool_target_mask = tf.cast(target_mask, tf.bool)
        indices = tf.boolean_mask(indices, bool_target_mask)

        ##### extra padding due to CLS/SEP introduced after prepro
        actual_num_predict = tf.shape(indices)[1]
        pad_len = module._num_predict - actual_num_predict

        ##### target_mapping
        target_mapping = tf.one_hot(indices, module.max_seq_length,
                                    dtype=tf.float32)
        paddings = tf.zeros([pad_len, module.max_seq_length],
                            dtype=target_mapping.dtype)
        target_mapping = tf.concat([target_mapping, paddings], axis=0)
        split_placeholders['target_mapping'] = tf.reshape(
            target_mapping,
            [-1, module._num_predict, module.max_seq_length])

        ##### target
        target = tf.boolean_mask(target, bool_target_mask)
        paddings = tf.zeros([pad_len], dtype=target.dtype)
        target = tf.concat([target, paddings], axis=0)
        split_placeholders['target'] = tf.reshape(
            target, [-1, module._num_predict])

        ##### target mask
        target_mask = tf.concat([
            tf.ones([batch_size, actual_num_predict], dtype=tf.float32),
            tf.zeros([batch_size, pad_len], dtype=tf.float32)
        ], axis=1)
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module._num_predict])
    else:
        split_placeholders['target'] = tf.reshape(
            target, [-1, module.max_seq_length])
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module.max_seq_length])

    # reshape back to fixed shape
    split_placeholders['perm_mask'] = tf.reshape(
        perm_mask, [-1, module.max_seq_length, module.max_seq_length])
    split_placeholders['input_k'] = tf.reshape(
        input_k, [-1, module.max_seq_length])
    split_placeholders['input_q'] = tf.reshape(
        input_q, [-1, module.max_seq_length])

    return split_placeholders
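
# A shape-only sketch (assumption, not part of the module above) of the block
# permutation mask assembled in `_expand_features`. Under the XLNet convention
# that `perm_mask[b, i, j] = 1` means token i may not attend to token j, the
# reused segment is blocked from the non-reused segment (all-ones block) while
# the non-reused segment may always attend to the reused one (all-zeros block):
#
#   perm_mask = [[ perm_mask_0 |      1      ],
#                [      0      | perm_mask_1 ]]
import tensorflow as tf

batch_size, reuse_len, non_reuse_len = 2, 3, 5
perm_mask_0 = tf.random_uniform([batch_size, reuse_len, reuse_len])
perm_mask_1 = tf.random_uniform([batch_size, non_reuse_len, non_reuse_len])

top = tf.concat(
    [perm_mask_0, tf.ones([batch_size, reuse_len, non_reuse_len])], axis=2)
bottom = tf.concat(
    [tf.zeros([batch_size, non_reuse_len, reuse_len]), perm_mask_1], axis=2)
perm_mask = tf.concat([top, bottom], axis=1)
print(perm_mask.shape)   # (2, 8, 8)
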
def __init__(self,
             is_training,
             input_tensor,
             is_supervised,
             is_expanded,
             label_ids,
             label_size=2,
             sample_weight=None,
             scope='cls/seq_relationship',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             global_step=None,
             num_train_steps=None,
             uda_softmax_temp=-1,
             uda_confidence_thresh=-1,
             tsa_schedule='linear',
             **kwargs):
    super().__init__(**kwargs)

    is_supervised = tf.cast(is_supervised, tf.float32)
    is_expanded = tf.cast(is_expanded, tf.float32)

    hidden_size = input_tensor.shape.as_list()[-1]
    with tf.variable_scope(scope):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[label_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        output_layer = util.dropout(
            input_tensor, hidden_dropout_prob if is_training else 0.0)
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

    with tf.variable_scope('sup_loss'):

        # select original (non-augmented), supervised examples
        sup_ori_log_probs = tf.boolean_mask(
            log_probs, mask=(1.0 - is_expanded), axis=0)
        sup_log_probs = tf.boolean_mask(
            sup_ori_log_probs, mask=is_supervised, axis=0)
        sup_label_ids = tf.boolean_mask(
            label_ids, mask=is_supervised, axis=0)

        self.preds['preds'] = tf.argmax(sup_ori_log_probs, axis=-1)

        one_hot_labels = tf.one_hot(
            sup_label_ids, depth=label_size, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(
            one_hot_labels * sup_log_probs, axis=-1)
        loss_mask = tf.ones_like(per_example_loss, dtype=tf.float32)
        correct_label_probs = tf.reduce_sum(
            one_hot_labels * tf.exp(sup_log_probs), axis=-1)

        # Training Signal Annealing: mask out supervised examples the model
        # already predicts with probability above the current threshold.
        if is_training and tsa_schedule:
            tsa_start = 1.0 / label_size
            tsa_threshold = get_tsa_threshold(
                tsa_schedule, global_step, num_train_steps,
                tsa_start, end=1)

            larger_than_threshold = tf.greater(
                correct_label_probs, tsa_threshold)
            loss_mask = loss_mask * (
                1 - tf.cast(larger_than_threshold, tf.float32))

        loss_mask = tf.stop_gradient(loss_mask)
        per_example_loss = per_example_loss * loss_mask
        if sample_weight is not None:
            sup_sample_weight = tf.boolean_mask(
                sample_weight, mask=is_supervised, axis=0)
            per_example_loss *= tf.cast(
                sup_sample_weight, dtype=tf.float32)
        sup_loss = (tf.reduce_sum(per_example_loss) /
                    tf.maximum(tf.reduce_sum(loss_mask), 1))
        self.losses['supervised'] = per_example_loss

    with tf.variable_scope('unsup_loss'):

        # split original vs. augmented copies of the unsupervised examples
        ori_log_probs = tf.boolean_mask(
            sup_ori_log_probs, mask=(1.0 - is_supervised), axis=0)
        aug_log_probs = tf.boolean_mask(
            log_probs, mask=is_expanded, axis=0)
        sup_ori_logits = tf.boolean_mask(
            logits, mask=(1.0 - is_expanded), axis=0)
        ori_logits = tf.boolean_mask(
            sup_ori_logits, mask=(1.0 - is_supervised), axis=0)

        unsup_loss_mask = 1
        if uda_softmax_temp != -1:
            # sharpen the target distribution with a softmax temperature
            tgt_ori_log_probs = tf.nn.log_softmax(
                ori_logits / uda_softmax_temp, axis=-1)
            tgt_ori_log_probs = tf.stop_gradient(tgt_ori_log_probs)
        else:
            tgt_ori_log_probs = tf.stop_gradient(ori_log_probs)

        if uda_confidence_thresh != -1:
            # only keep examples the model is confident about
            largest_prob = tf.reduce_max(tf.exp(ori_log_probs), axis=-1)
            unsup_loss_mask = tf.cast(
                tf.greater(largest_prob, uda_confidence_thresh),
                tf.float32)
            unsup_loss_mask = tf.stop_gradient(unsup_loss_mask)

        # consistency loss: KL(original || augmented)
        per_example_loss = kl_for_log_probs(
            tgt_ori_log_probs, aug_log_probs) * unsup_loss_mask
        if sample_weight is not None:
            unsup_sample_weight = tf.boolean_mask(
                sample_weight, mask=(1.0 - is_supervised), axis=0)
            per_example_loss *= tf.cast(
                unsup_sample_weight, dtype=tf.float32)
        unsup_loss = tf.reduce_mean(per_example_loss)
        self.losses['unsupervised'] = per_example_loss

    self.total_loss = sup_loss + unsup_loss
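
# Hedged sketches of the two helpers referenced above but not defined in this
# listing (`get_tsa_threshold` and `kl_for_log_probs`). They follow the UDA
# paper's definitions and may differ in detail from the shipped versions.
import tensorflow as tf

def get_tsa_threshold(schedule, global_step, num_train_steps, start, end):
    # Training Signal Annealing: the threshold grows from `start` to `end`
    # over training, following a linear / exp / log schedule.
    progress = tf.to_float(global_step) / tf.to_float(num_train_steps)
    if schedule == 'linear':
        coeff = progress
    elif schedule == 'exp':
        coeff = tf.exp((progress - 1) * 5)
    elif schedule == 'log':
        coeff = 1 - tf.exp(-progress * 5)
    else:
        raise ValueError('Unknown tsa_schedule: %s' % schedule)
    return coeff * (end - start) + start

def kl_for_log_probs(log_p, log_q):
    # KL(p || q) computed directly from log-probabilities.
    p = tf.exp(log_p)
    return tf.reduce_sum(p * (log_p - log_q), axis=-1)
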
def dynamic_transformer_model(self,
                              is_training,
                              input_tensor,
                              input_mask,
                              batch_size,
                              max_seq_length,
                              label_size,
                              attention_mask=None,
                              hidden_size=768,
                              num_hidden_layers=12,
                              num_attention_heads=12,
                              intermediate_size=3072,
                              intermediate_act_fn=util.gelu,
                              hidden_dropout_prob=0.1,
                              attention_probs_dropout_prob=0.1,
                              initializer_range=0.02,
                              dtype=tf.float32,
                              cls_model='self-attention',
                              cls_hidden_size=128,
                              cls_num_attention_heads=2,
                              speed=0.1,
                              ignore_cls=None):
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            'The hidden size (%d) is not a multiple of the number of '
            'attention heads (%d)' % (hidden_size, num_attention_heads))
    attention_head_size = int(hidden_size / num_attention_heads)

    # guard against `ignore_cls=None` so the membership tests below work
    if ignore_cls is None:
        ignore_cls = []
    keep_cls = list(range(num_hidden_layers + 1))
    keep_cls = [cls_idx for cls_idx in keep_cls
                if cls_idx not in ignore_cls]

    all_layer_outputs = []
    all_layer_cls_outputs = collections.OrderedDict()
    prev_output = input_tensor
    prev_mask = input_mask
    for layer_idx in range(num_hidden_layers):
        with tf.variable_scope('layer_%d' % layer_idx):

            # build child classifier
            if is_training or layer_idx not in ignore_cls:
                with tf.variable_scope('distill'):

                    # FCN + Self_Attention + FCN + FCN
                    if cls_model == 'self-attention-paper':
                        cls_output = self._cls_self_attention_paper(
                            prev_output,
                            batch_size,
                            max_seq_length,
                            label_size,
                            attention_mask=attention_mask,
                            cls_hidden_size=cls_hidden_size,
                            cls_num_attention_heads=cls_num_attention_heads,
                            attention_probs_dropout_prob=attention_probs_dropout_prob,
                            initializer_range=initializer_range,
                            dtype=tf.float32,
                            trainable=True)

                    # Self_Attention + FCN
                    elif cls_model == 'self-attention':
                        cls_output = self._cls_self_attention(
                            prev_output,
                            batch_size,
                            max_seq_length,
                            label_size,
                            attention_mask=attention_mask,
                            cls_hidden_size=cls_hidden_size,
                            cls_num_attention_heads=cls_num_attention_heads,
                            attention_probs_dropout_prob=attention_probs_dropout_prob,
                            initializer_range=initializer_range,
                            dtype=tf.float32,
                            trainable=True)

                    # FCN
                    elif cls_model == 'fcn':
                        cls_output = self._cls_fcn(
                            prev_output,
                            label_size,
                            hidden_size=hidden_size,
                            initializer_range=initializer_range,
                            dtype=tf.float32,
                            trainable=True)

                    else:
                        raise ValueError(
                            'Invalid `cls_model = %s`. Pick one from '
                            '`self-attention-paper`, `self-attention` '
                            'and `fcn`' % cls_model)

                    # distill core: normalized entropy as uncertainty
                    layer_cls_output = tf.nn.softmax(
                        cls_output, axis=-1, name='cls_%d' % layer_idx)
                    uncertainty = tf.reduce_sum(
                        layer_cls_output * tf.log(layer_cls_output), axis=-1)
                    uncertainty /= tf.log(1 / label_size)

                # branching only in inference
                if not is_training:

                    # last output
                    if layer_idx == keep_cls[-1]:
                        all_layer_outputs.append(prev_output)
                        all_layer_cls_outputs[layer_idx] = layer_cls_output
                        return (all_layer_outputs, all_layer_cls_outputs)

                    # drop examples whose uncertainty is already below
                    # the `speed` threshold (early exit)
                    mask = tf.less(uncertainty, speed)
                    unfinished_mask = (
                        tf.ones_like(mask, dtype=dtype) -
                        tf.cast(mask, dtype=dtype))
                    prev_output = tf.boolean_mask(
                        prev_output, mask=unfinished_mask, axis=0)
                    prev_mask = tf.boolean_mask(
                        prev_mask, mask=unfinished_mask, axis=0)
                all_layer_cls_outputs[layer_idx] = layer_cls_output

                # new attention mask for the (possibly reduced) batch
                input_shape = util.get_shape_list(prev_output)
                batch_size = input_shape[0]
                max_seq_length = input_shape[1]
                attention_mask = \
                    self.create_attention_mask_from_input_mask(
                        prev_mask, batch_size, max_seq_length, dtype=dtype)

            # original stream
            with tf.variable_scope('attention'):
                attention_heads = []
                with tf.variable_scope('self'):
                    (attention_head, _) = self.attention_layer(
                        from_tensor=prev_output,
                        to_tensor=prev_output,
                        attention_mask=attention_mask,
                        num_attention_heads=num_attention_heads,
                        size_per_head=attention_head_size,
                        attention_probs_dropout_prob=attention_probs_dropout_prob,
                        initializer_range=initializer_range,
                        do_return_2d_tensor=False,
                        batch_size=batch_size,
                        from_max_seq_length=max_seq_length,
                        to_max_seq_length=max_seq_length,
                        dtype=dtype,
                        trainable=False)
                    attention_heads.append(attention_head)

                attention_output = None
                if len(attention_heads) == 1:
                    attention_output = attention_heads[0]
                else:
                    attention_output = tf.concat(attention_heads, axis=-1)

                with tf.variable_scope('output'):
                    attention_output = tf.layers.dense(
                        attention_output,
                        hidden_size,
                        kernel_initializer=util.create_initializer(
                            initializer_range),
                        trainable=False)
                    attention_output = util.dropout(
                        attention_output, hidden_dropout_prob)
                    attention_output = util.layer_norm(
                        attention_output + prev_output, trainable=False)

            # The activation is only applied to the `intermediate`
            # hidden layer.
            with tf.variable_scope('intermediate'):
                intermediate_output = tf.layers.dense(
                    attention_output,
                    intermediate_size,
                    activation=intermediate_act_fn,
                    kernel_initializer=util.create_initializer(
                        initializer_range),
                    trainable=False)

            # Down-project back to hidden_size then add the residual.
            with tf.variable_scope('output'):
                layer_output = tf.layers.dense(
                    intermediate_output,
                    hidden_size,
                    kernel_initializer=util.create_initializer(
                        initializer_range),
                    trainable=False)
                layer_output = util.dropout(
                    layer_output, hidden_dropout_prob)
                layer_output = util.layer_norm(
                    layer_output + attention_output, trainable=False)

            prev_output = layer_output
            all_layer_outputs.append(layer_output)

    return (all_layer_outputs, all_layer_cls_outputs)
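
# A minimal standalone sketch (assumption, not part of the model above) of
# the early-exit criterion used in `dynamic_transformer_model`: uncertainty
# is the entropy of the child classifier's distribution normalized by
# log(1 / label_size), so it falls in [0, 1]; rows whose uncertainty drops
# below `speed` stop at the current layer during inference.
import tensorflow as tf

label_size, speed = 3, 0.1
layer_cls_output = tf.constant([[0.99, 0.005, 0.005],   # confident row
                                [0.40, 0.350, 0.250]])  # uncertain row
uncertainty = tf.reduce_sum(layer_cls_output * tf.log(layer_cls_output), axis=-1)
uncertainty /= tf.log(1 / label_size)
exits_here = tf.less(uncertainty, speed)

with tf.Session() as sess:
    print(sess.run(exits_here))   # [ True False]
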