def get_classification_outputs(FLAGS, features, is_training):
    """Build the XLNet classification head and collect its outputs.

    No loss is computed here; callers derive it from the returned
    log-probabilities.

    Args:
        FLAGS: config object; reads `num_choices`, `num_classes`,
            `model_config_path`, `summary_type` and `use_summ_proj`.
        features: dict holding the "input_ids" and "segment_ids" tensors,
            batch on the first axis.
        is_training: bool forwarded to the XLNet run config; it also decides
            whether "cls_log_probs" is part of the result.

    Returns:
        dict with "cls_logits" and, only when `is_training` is true,
        "cls_log_probs".
    """
    token_ids = features["input_ids"]
    segment_ids = features["segment_ids"]

    # Token id 0 counts as padding: the mask is 1.0 there, 0.0 elsewhere.
    non_pad = tf.cast(tf.cast(token_ids, tf.bool), tf.int32)
    pad_mask = 1 - tf.cast(non_pad, tf.float32)

    num_choices = FLAGS.num_choices
    batch_size = tf.shape(features["input_ids"])[0]

    if num_choices:
        def _to_model_layout(tensor):
            # Split out the choice axis, bring the remaining axis to the
            # front, then fold the choices into the batch axis.
            per_choice = tf.reshape(tensor, [batch_size, num_choices, -1])
            seq_first = tf.transpose(per_choice, [2, 0, 1])
            return tf.reshape(seq_first, [-1, batch_size * num_choices])
    else:
        def _to_model_layout(tensor):
            # Swap the two axes so the batch axis comes second.
            return tf.transpose(tensor, [1, 0])

    input_ids = _to_model_layout(token_ids)
    seg_id = _to_model_layout(segment_ids)
    input_mask = _to_model_layout(pad_mask)

    xlnet_model = xlnet.XLNetModel(
        xlnet_config=xlnet.XLNetConfig(json_path=FLAGS.model_config_path),
        run_config=xlnet.create_run_config(is_training, True, FLAGS),
        input_ids=input_ids,
        seg_ids=seg_id,
        input_mask=input_mask)

    summary = xlnet_model.get_pooled_out(FLAGS.summary_type,
                                         FLAGS.use_summ_proj)
    initializer = xlnet_model.get_initializer()

    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        with tf.variable_scope("answer_class"):
            # Multiple choice (e.g. race, 4 classes) scores every choice with
            # a single logit; otherwise (e.g. boolq, 2 classes) the head
            # emits FLAGS.num_classes logits.
            num_classes = 1 if num_choices else FLAGS.num_classes
            cls_logits = tf.layers.dense(summary,
                                         num_classes,
                                         kernel_initializer=initializer,
                                         name="cls")
            if num_choices:
                # Regroup the per-choice scores by example.
                cls_logits = tf.reshape(cls_logits,
                                        [batch_size, num_choices])
            cls_log_probs = tf.nn.log_softmax(cls_logits, -1)

    outputs = {}
    if is_training:
        outputs["cls_log_probs"] = cls_log_probs
    outputs["cls_logits"] = cls_logits
    return outputs
def relative_positional_encoding(qlen, klen, d_model, clamp_len, attn_type,
                                 bi_data=None, bsz=None, dtype=None):
    """Create the relative positional encoding.

    Relative positions run from `klen` down towards `-qlen` (exclusive) for
    bidirectional attention, or down towards `-1` (exclusive) for
    unidirectional attention, and are embedded by `positional_embedding`.

    Args:
        qlen: query length; only used when `attn_type` is 'bi'.
        klen: key length; the first (largest) relative position.
        d_model: hidden size; sets the number of sinusoid frequencies.
        clamp_len: when > 0, positions are clipped to
            [-clamp_len, clamp_len].
        attn_type: 'bi' or 'uni'.
        bi_data: accepted for signature compatibility but unused; the
            forward/backward branch that consumed it is disabled.
        bsz: forwarded to `positional_embedding`.
        dtype: optional float type; tensors are cast when it is set and is
            not tf.float32.

    Returns:
        The result of `positional_embedding` for the position sequence.

    Raises:
        ValueError: if `attn_type` is neither 'bi' nor 'uni'.
    """
    needs_cast = dtype is not None and dtype != tf.float32

    freq_seq = tf.range(0, d_model, 2.0)
    if needs_cast:
        freq_seq = tf.cast(freq_seq, dtype=dtype)
    inv_freq = 1 / (10000**(freq_seq / d_model))

    # The range starts at `klen` (not `klen - 1`) for both attention types.
    if attn_type == 'uni':
        start, stop = klen, -1
    elif attn_type == 'bi':
        start, stop = klen, -qlen
    else:
        raise ValueError('Unknown `attn_type` {}.'.format(attn_type))

    positions = tf.range(start, stop, -1.0)
    if needs_cast:
        positions = tf.cast(positions, dtype=dtype)
    if clamp_len > 0:
        positions = tf.clip_by_value(positions, -clamp_len, clamp_len)

    return positional_embedding(positions, inv_freq, bsz)
def _convert_example(example, use_bfloat16):
    """Normalize the dtypes of every tensor in `example`, in place.

    Sparse tensors are densified, int64 becomes int32, and float32 becomes
    bfloat16 when `use_bfloat16` is set. Nothing is returned; the dict
    passed in is updated.
    """
    def _normalize(tensor):
        if tf.keras.backend.is_sparse(tensor):
            tensor = tf.sparse.to_dense(tensor)
        if tensor.dtype == tf.int64:
            tensor = tf.cast(tensor, tf.int32)
        if use_bfloat16 and tensor.dtype == tf.float32:
            tensor = tf.cast(tensor, tf.bfloat16)
        return tensor

    # Snapshot the items first: the dict is rewritten while we walk it.
    for name, tensor in list(example.items()):
        example[name] = _normalize(tensor)
def get_attention_mask(input_ids, seq_len):
    """Build a padding mask in which every position may still see itself.

    Args:
        input_ids: ids laid out as [seq_len, batch_size]; id 0 is treated
            as padding.
        seq_len: size of the square mask.

    Returns:
        float32 tensor [seq_len, seq_len, batch_size, 1]. An entry is 1.0
        where the second-axis position is padding, except on the diagonal,
        which is always 0.0.
    """
    # [seq_len, seq_len]: -1 on the diagonal, 0 elsewhere.
    neg_identity = -tf.eye(seq_len, dtype=tf.float32)

    # [seq_len, batch_size]: 1.0 for padding, 0.0 for real tokens.
    padding = 1 - tf.cast(tf.cast(input_ids, tf.bool), tf.float32)

    # [1, seq_len, batch_size, 1]: the same padding row for every query.
    key_padding = padding[None, :, :, None]

    # Adding -1 on the diagonal pushes self-attention entries to <= 0, so
    # they survive the `> 0` test as "not masked".
    combined = key_padding + neg_identity[:, :, None, None]
    return tf.cast(combined > 0, dtype=tf.float32)
def _decode_record(record, name_to_features):
    """Parse a serialized record and downcast its int64 features to int32.

    tf.Example only stores integers as tf.int64, while the TPU only
    supports tf.int32, hence the cast.
    """
    example = tf.parse_single_example(record, name_to_features)

    # Iterate over a snapshot because entries are replaced in the loop.
    for name, tensor in list(example.items()):
        if tensor.dtype == tf.int64:
            example[name] = tf.cast(tensor, tf.int32)

    return example
def model_fn(features, labels, mode, params):
    """Estimator `model_fn` built on the RACE loss: train or eval spec.

    `FLAGS`, `function_builder`, `model_utils`, `logger` and `np` are taken
    from the surrounding scope. `labels` and `params` are part of the
    Estimator signature but are not read here.

    Returns:
        A (TPU)EstimatorSpec for evaluation when `mode` is EVAL, otherwise a
        (TPU)EstimatorSpec carrying the training op.
    """
    mode_keys = tf.estimator.ModeKeys
    is_training = mode == mode_keys.TRAIN

    total_loss, per_example_loss, logits = function_builder.get_race_loss(
        FLAGS, features, is_training)

    # Report the number of trainable parameters.
    num_params = sum(np.prod(var.shape) for var in tf.trainable_variables())
    logger.info('#params: {}'.format(num_params))

    # Load the pretrained weights.
    scaffold_fn = model_utils.init_from_checkpoint(FLAGS)

    if mode == mode_keys.EVAL:
        assert FLAGS.num_hosts == 1

        def metric_fn(per_example_loss, label_ids, logits, is_real_example):
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
            accuracy = tf.metrics.accuracy(labels=label_ids,
                                           predictions=predictions,
                                           weights=is_real_example)
            loss = tf.metrics.mean(values=per_example_loss,
                                   weights=is_real_example)
            return {'eval_accuracy': accuracy, 'eval_loss': loss}

        is_real_example = tf.cast(features["is_real_example"],
                                  dtype=tf.float32)
        label_ids = tf.reshape(features['label_ids'], [-1])
        metric_args = [per_example_loss, label_ids, logits, is_real_example]

        if not FLAGS.use_tpu:
            return tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metric_ops=metric_fn(*metric_args))

        # On TPU the metric function is handed over uncalled.
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metrics=(metric_fn, metric_args),
            scaffold_fn=scaffold_fn)

    # Training: configure the optimizer.
    train_op, learning_rate, _ = model_utils.get_train_op(FLAGS, total_loss)
    # Collected for monitoring; no host call consumes it in this function.
    monitor_dict = {"lr": learning_rate}

    if not FLAGS.use_tpu:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=total_loss, train_op=train_op)

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        host_call=None,
        scaffold_fn=scaffold_fn)
def model_fn(features, labels, mode, params):
    """Estimator `model_fn` for classification: predict or train spec.

    `FLAGS`, `function_builder`, `model_utils`, `logger` and `np` are taken
    from the surrounding scope. `labels` and `params` are part of the
    Estimator signature but are not read here.

    Returns:
        In PREDICT mode, a spec whose predictions hold "feature_id" and
        "cls_logits". In every other mode, a spec carrying the cross-entropy
        loss over `features["cls"]` and the training op.
    """
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    return_dict = function_builder.get_classification_outputs(
        FLAGS, features, is_training)
    cls_logits = return_dict["cls_logits"]

    # Report the number of trainable parameters.
    num_params = sum(np.prod(var.shape) for var in tf.trainable_variables())
    logger.info('#params: {}'.format(num_params))

    # Load the pretrained weights.
    scaffold_fn = model_utils.init_from_checkpoint(FLAGS)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "feature_id": features["feature_id"],
            "cls_logits": cls_logits,
        }
        if not FLAGS.use_tpu:
            return tf.estimator.EstimatorSpec(mode=mode,
                                              predictions=predictions)
        return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                               predictions=predictions,
                                               scaffold_fn=scaffold_fn)

    cls_log_probs = return_dict["cls_log_probs"]

    # With multiple choice the label indexes the choices, otherwise the
    # classes.
    num_choices = FLAGS.num_choices
    num_classes = num_choices if num_choices else FLAGS.num_classes

    # Mean negative log-likelihood of the gold label.
    gold_one_hot = tf.one_hot(features["cls"], depth=num_classes,
                              dtype=tf.float32)
    example_nll = -tf.reduce_sum(gold_one_hot * cls_log_probs, axis=-1)
    total_loss = tf.reduce_mean(example_nll)

    # Configure the optimizer.
    train_op, learning_rate, _ = model_utils.get_train_op(FLAGS, total_loss)
    monitor_dict = {'loss/cls': total_loss, "lr": learning_rate}

    if not FLAGS.use_tpu:
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=total_loss,
                                          train_op=train_op)

    if FLAGS.is_regression:
        host_call = None
    else:
        # Track training accuracy through a scalar host call.
        label_ids = tf.reshape(features['cls'], [-1])
        predicted = tf.argmax(cls_logits, axis=-1,
                              output_type=label_ids.dtype)
        hits = tf.cast(tf.equal(predicted, label_ids), tf.float32)
        monitor_dict["accuracy"] = tf.reduce_mean(hits)
        host_call = function_builder.construct_scalar_host_call(
            monitor_dict=monitor_dict,
            model_dir=FLAGS.model_dir,
            prefix="train/",
            reduce_fn=tf.reduce_mean)

    return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                           loss=total_loss,
                                           train_op=train_op,
                                           host_call=host_call,
                                           scaffold_fn=scaffold_fn)
def get_train_op(FLAGS, total_loss, grads_and_vars=None,
                 trainable_variables=None):
    """Build the training op: LR warmup/decay, clipping and the optimizer.

    Args:
        FLAGS: config object; reads `warmup_steps`, `learning_rate`,
            `decay_method`, `train_steps`, `min_lr_ratio`, `weight_decay`,
            `use_tpu`, `num_core_per_host`, `adam_epsilon`, `clip` and,
            optionally, `lr_layer_decay_rate`.
        total_loss: scalar loss to minimize.
        grads_and_vars: optional precomputed (gradient, variable) pairs;
            when None they are computed from `total_loss`.
        trainable_variables: optional variable list passed to
            `compute_gradients` when gradients are computed here.

    Returns:
        Tuple `(train_op, learning_rate, gnorm)`, where `gnorm` is the
        global gradient norm reported by `tf.clip_by_global_norm`.

    Raises:
        ValueError: on an unknown `decay_method`, or when weight decay is
            requested for multi-GPU (non-TPU) training.
    """
    global_step = tf.train.get_or_create_global_step()

    # increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
        warmup_lr = (tf.cast(global_step, tf.float32)
                     / tf.cast(FLAGS.warmup_steps, tf.float32)
                     * FLAGS.learning_rate)
    else:
        warmup_lr = 0.0

    # decay the learning rate
    if FLAGS.decay_method == "poly":
        decay_lr = tf.train.polynomial_decay(
            FLAGS.learning_rate,
            global_step=global_step - FLAGS.warmup_steps,
            decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
            end_learning_rate=FLAGS.learning_rate * FLAGS.min_lr_ratio)
    elif FLAGS.decay_method == "cos":
        decay_lr = tf.train.cosine_decay(
            FLAGS.learning_rate,
            global_step=global_step - FLAGS.warmup_steps,
            decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
            alpha=FLAGS.min_lr_ratio)
    else:
        raise ValueError(FLAGS.decay_method)

    learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                             warmup_lr, decay_lr)

    if (FLAGS.weight_decay > 0 and not FLAGS.use_tpu and
            FLAGS.num_core_per_host > 1):
        raise ValueError("Do not support `weight_decay > 0` with multi-gpu "
                         "training so far.")

    if FLAGS.weight_decay == 0:
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                           epsilon=FLAGS.adam_epsilon)
    else:
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            epsilon=FLAGS.adam_epsilon,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
            weight_decay_rate=FLAGS.weight_decay)

    if FLAGS.use_tpu:
        optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    if grads_and_vars is None:
        grads_and_vars = optimizer.compute_gradients(
            total_loss, var_list=trainable_variables)
    gradients, variables = zip(*grads_and_vars)
    clipped, gnorm = tf.clip_by_global_norm(gradients, FLAGS.clip)

    if getattr(FLAGS, "lr_layer_decay_rate", 1.0) != 1.0:
        # Entries are rescaled in place below, so make sure this is a list.
        clipped = list(clipped)

        # The layer count is one more than the highest layer index found in
        # the variable names.
        n_layer = 0
        for var in variables:
            m = re.search(r"model/transformer/layer_(\d+?)/", var.name)
            if m:
                n_layer = max(n_layer, int(m.group(1)) + 1)

        for i, var in enumerate(variables):
            # `clip_by_global_norm` passes None gradients through untouched;
            # scaling one would raise a TypeError, so leave it for the
            # optimizer to handle.
            if clipped[i] is None:
                continue
            for layer_idx in range(n_layer):
                if "model/transformer/layer_{}/".format(
                        layer_idx) in var.name:
                    # The top layer keeps the full rate; each layer below is
                    # scaled by one more factor of `lr_layer_decay_rate`.
                    abs_rate = FLAGS.lr_layer_decay_rate**(
                        n_layer - 1 - layer_idx)
                    clipped[i] *= abs_rate
                    logger.info(
                        "Apply mult {:.4f} to layer-{} grad of {}".format(
                            abs_rate, layer_idx, var.name))
                    break

    train_op = optimizer.apply_gradients(zip(clipped, variables),
                                         global_step=global_step)

    # Manually increment `global_step` for AdamWeightDecayOptimizer
    if FLAGS.weight_decay > 0:
        new_global_step = global_step + 1
        train_op = tf.group(train_op, [global_step.assign(new_global_step)])

    return train_op, learning_rate, gnorm
def transformer_xl(
        input_ids, n_token, n_layer, d_model, n_head, d_head, d_inner,
        dropout, dropatt, attn_type, is_training, initializer,
        mem_len=None, mems=None, same_length=False, clamp_len=-1,
        untie_r=False, use_tpu=True, input_mask=None, seg_id=None,
        reuse_len=None, target_mapping=None, ff_activation='relu',
        use_bfloat16=False, scope='transformer', **kwargs):
    """Define a single-stream Transformer-XL graph as used by XLNet.

    This variant has memory caching, permutation masks and two-stream
    attention disabled: every layer runs `rel_multihead_attn` followed by
    `positionwise_ffn` on the content stream only.

    Args:
        input_ids: int32 Tensor in shape [len, bsz], the input token IDs.
        n_token: int, the vocab size.
        n_layer: int, the number of layers.
        d_model: int, the hidden size.
        n_head: int, the number of attention heads.
        d_head: int, the dimension size of each attention head.
        d_inner: int, the hidden size in feed-forward layers.
        dropout: float, dropout rate.
        dropatt: float, dropout rate on attention probabilities.
        attn_type: str, forwarded to `relative_positional_encoding`.
        is_training: bool, whether in training mode.
        initializer: A tf initializer.
        mem_len: unused; memory caching is disabled.
        mems: ignored; it is replaced by a list of `None`, one per layer.
        same_length: unused in this variant.
        clamp_len: int, clamp all relative distances larger than clamp_len.
            -1 means no clamping.
        untie_r: bool, whether to untie the biases in attention.
        use_tpu: bool, whether TPUs are used.
        input_mask: float32 Tensor in shape [len, bsz], the input mask.
            0 for real tokens and 1 for padding. When None, no attention
            mask is applied.
        seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
            When None, segment embeddings are skipped.
        reuse_len: unused; memory caching is disabled.
        target_mapping: unused; partial prediction is disabled.
        ff_activation: str, "relu" or "gelu".
        use_bfloat16: bool, use bfloat16 instead of float32.
        scope: scope name for the computation graph.
        **kwargs: extra keyword arguments are accepted and ignored.

    Returns:
        Tuple `(output, new_mems, lookup_table)`: the final hidden states
        after dropout, a list of `n_layer` `None` placeholders, and the
        embedding table returned by `embedding_lookup`.
    """
    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32
    logger.info('Use float type {}'.format(tf_float))

    new_mems = []
    with tf.variable_scope(scope):
        if untie_r:
            r_w_bias = tf.get_variable('r_w_bias',
                                       [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias',
                                       [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
        else:
            r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)

        batch_size = tf.shape(input_ids)[1]
        seq_len = tf.shape(input_ids)[0]
        # Memory is disabled, so the memory length is always zero.
        mlen = 0
        klen = mlen + seq_len

        # #### Attention mask
        # No causal mask is built; only the padding mask contributes.
        attn_mask = None

        # `input_mask` defaults to None, so guard before indexing it.
        data_mask = None if input_mask is None else input_mask[None]

        if data_mask is not None:
            # all mems can be attended to
            mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, batch_size],
                                 dtype=tf_float)
            data_mask = tf.concat([mems_mask, data_mask], 1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = tf.cast(attn_mask > 0, dtype=tf_float)
            # Subtracting the identity lets each position see itself even
            # when its own column is masked.
            non_tgt_mask = -tf.eye(seq_len, dtype=tf_float)
            non_tgt_mask = tf.concat(
                [tf.zeros([seq_len, mlen], dtype=tf_float), non_tgt_mask],
                axis=-1)
            non_tgt_mask = tf.cast(
                (attn_mask + non_tgt_mask[:, :, None, None]) > 0,
                dtype=tf_float)
        else:
            non_tgt_mask = None

        # #### Word embedding
        word_emb_k, lookup_table = embedding_lookup(x=input_ids,
                                                    n_token=n_token,
                                                    d_embed=d_model,
                                                    initializer=initializer,
                                                    use_tpu=use_tpu,
                                                    dtype=tf_float,
                                                    scope='word_embedding')
        output_h = tf.layers.dropout(word_emb_k, dropout,
                                     training=is_training)

        # #### Segment embedding
        if seg_id is not None:
            if untie_r:
                r_s_bias = tf.get_variable('r_s_bias',
                                           [n_layer, n_head, d_head],
                                           dtype=tf_float,
                                           initializer=initializer)
            else:
                # default case (tie)
                r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head],
                                           dtype=tf_float,
                                           initializer=initializer)

            seg_embed = tf.get_variable('seg_embed',
                                        [n_layer, 2, n_head, d_head],
                                        dtype=tf_float,
                                        initializer=initializer)

            # Convert `seg_id` to one-hot `seg_mat`
            mem_pad = tf.zeros([mlen, batch_size], dtype=tf.int32)
            cat_ids = tf.concat([mem_pad, seg_id], 0)

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = tf.cast(
                tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
                tf.int32)
            seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
        else:
            seg_mat = None

        # #### Positional encoding
        pos_emb = relative_positional_encoding(seq_len, klen, d_model,
                                               clamp_len, attn_type,
                                               bsz=batch_size,
                                               dtype=tf_float)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

        # #### Attention layers
        # Caller-supplied memory is discarded: every layer runs without it.
        mems = [None] * n_layer

        for i in range(n_layer):
            # Nothing is cached; keep one placeholder per layer so the
            # return value keeps its per-layer length.
            new_mems.append(None)

            # segment bias
            if seg_id is None:
                r_s_bias_i = None
                seg_embed_i = None
            else:
                r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
                seg_embed_i = seg_embed[i]

            with tf.variable_scope('layer_{}'.format(i)):
                reuse = False
                output_h = rel_multihead_attn(
                    h=output_h,
                    r=pos_emb,
                    r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                    r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                    seg_mat=seg_mat,
                    r_s_bias=r_s_bias_i,
                    seg_embed=seg_embed_i,
                    attn_mask=non_tgt_mask,
                    mems=mems[i],
                    d_model=d_model,
                    n_head=n_head,
                    d_head=d_head,
                    dropout=dropout,
                    dropatt=dropatt,
                    is_training=is_training,
                    kernel_initializer=initializer,
                    reuse=reuse)

                output_h = positionwise_ffn(inp=output_h,
                                            d_model=d_model,
                                            d_inner=d_inner,
                                            dropout=dropout,
                                            kernel_initializer=initializer,
                                            activation_type=ff_activation,
                                            is_training=is_training,
                                            reuse=reuse)

        output = tf.layers.dropout(output_h, dropout, training=is_training)

        return output, new_mems, lookup_table
def parser(record):
    """Parse one serialized tfrecord into permutation-LM features.

    `seq_len`, `reuse_len`, `perm_size`, `num_predict` and `use_bfloat16`
    are taken from the surrounding scope. The record is split at
    `reuse_len`; each part gets its own permutation from `_local_perm` and
    the results are stitched back together.

    Returns:
        The example dict with "perm_mask", "input_k", "input_q", "target"
        and "target_mask" added (plus "target_mapping" when `num_predict`
        is set), after dtype normalization by `_convert_example`.
    """
    record_spec = {
        "input": tf.FixedLenFeature([seq_len], tf.int64),
        "target": tf.FixedLenFeature([seq_len], tf.int64),
        "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
        "label": tf.FixedLenFeature([1], tf.int64),
        "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
    }

    # retrieve serialized example
    example = tf.parse_single_example(serialized=record,
                                      features=record_spec)

    inputs = example.pop("input")
    target = example.pop("target")
    is_masked = tf.cast(example.pop("is_masked"), tf.bool)

    tail_len = seq_len - reuse_len
    assert perm_size <= reuse_len and perm_size <= tail_len

    # Permute the reused head and the remaining tail independently.
    head = slice(None, reuse_len)
    tail = slice(reuse_len, None)
    head_parts = _local_perm(inputs[head], target[head], is_masked[head],
                             perm_size, reuse_len)
    tail_parts = _local_perm(inputs[tail], target[tail], is_masked[tail],
                             perm_size, tail_len)

    # Stitch the two square masks into one [seq_len, seq_len] mask: head
    # rows get ones over the tail columns, tail rows get zeros over the
    # head columns.
    head_rows = tf.concat(
        [head_parts[0], tf.ones([reuse_len, tail_len])], axis=1)
    tail_rows = tf.concat(
        [tf.zeros([tail_len, reuse_len]), tail_parts[0]], axis=1)
    perm_mask = tf.concat([head_rows, tail_rows], axis=0)

    # The remaining outputs are simply concatenated along the sequence.
    target, target_mask, input_k, input_q = (
        tf.concat([head_item, tail_item], axis=0)
        for head_item, tail_item in zip(head_parts[1:], tail_parts[1:]))

    if num_predict is None:
        example["target"] = tf.reshape(target, [seq_len])
        example["target_mask"] = tf.reshape(target_mask, [seq_len])
    else:
        positions = tf.range(seq_len, dtype=tf.int64)
        is_target = tf.cast(target_mask, tf.bool)
        positions = tf.boolean_mask(positions, is_target)

        # extra padding due to CLS/SEP introduced after prepro
        actual_num_predict = tf.shape(positions)[0]
        pad_len = num_predict - actual_num_predict

        # target_mapping: one one-hot row per predicted position, padded
        # with zero rows up to `num_predict`.
        mapping = tf.one_hot(positions, seq_len, dtype=tf.float32)
        mapping_pad = tf.zeros([pad_len, seq_len], dtype=mapping.dtype)
        mapping = tf.concat([mapping, mapping_pad], axis=0)
        example["target_mapping"] = tf.reshape(mapping,
                                               [num_predict, seq_len])

        # target: only the predicted tokens, zero-padded.
        kept = tf.boolean_mask(target, is_target)
        kept = tf.concat([kept, tf.zeros([pad_len], dtype=kept.dtype)],
                         axis=0)
        example["target"] = tf.reshape(kept, [num_predict])

        # target mask: ones for real predictions, zeros for the padding.
        padded_mask = tf.concat(
            [tf.ones([actual_num_predict], dtype=tf.float32),
             tf.zeros([pad_len], dtype=tf.float32)],
            axis=0)
        example["target_mask"] = tf.reshape(padded_mask, [num_predict])

    # reshape back to fixed shape
    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
    example["input_k"] = tf.reshape(input_k, [seq_len])
    example["input_q"] = tf.reshape(input_q, [seq_len])

    _convert_example(example, use_bfloat16)

    for name, tensor in example.items():
        logger.info("%s: %s", name, tensor)

    return example
def _local_perm(inputs, targets, is_masked, perm_size, seq_len): """ Sample a permutation of the factorization order, and create an attention mask accordingly. Args: inputs: int64 Tensor in shape [seq_len], input ids. targets: int64 Tensor in shape [seq_len], target ids. is_masked: bool Tensor in shape [seq_len]. True means being selected for partial prediction. perm_size: the length of longest permutation. Could be set to be reuse_len. Should not be larger than reuse_len or there will be data leaks. seq_len: int, sequence length. """ # Generate permutation indices index = tf.range(seq_len, dtype=tf.int64) index = tf.transpose(tf.reshape(index, [-1, perm_size])) index = tf.random_shuffle(index) index = tf.reshape(tf.transpose(index), [-1]) # `perm_mask` and `target_mask` # non-functional tokens non_func_tokens = tf.logical_not(tf.logical_or( tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID))) non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens) masked_or_func_tokens = tf.logical_not(non_mask_tokens) # Set the permutation indices of non-masked (& non-funcional) tokens to the # smallest index (-1): # (1) they can be seen by all other positions # (2) they cannot see masked positions, so there won"t be information leak smallest_index = -tf.ones([seq_len], dtype=tf.int64) rev_index = tf.where(non_mask_tokens, smallest_index, index) # Create `target_mask`: non-funcional and maksed tokens # 1: use mask as input and have loss # 0: use token (or [SEP], [CLS]) as input and do not have loss target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens) target_mask = tf.cast(target_tokens, tf.float32) # Create `perm_mask` # `target_tokens` cannot see themselves self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1) # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens) # 0: can attend if i > j or j is non-masked perm_mask = tf.logical_and( self_rev_index[:, None] <= rev_index[None, :], masked_or_func_tokens) 
perm_mask = tf.cast(perm_mask, tf.float32) # new target: [next token] for LM and [curr token] (self) for PLM new_targets = tf.concat([inputs[0: 1], targets[: -1]], axis=0) # construct inputs_k inputs_k = inputs # construct inputs_q inputs_q = target_mask return perm_mask, new_targets, target_mask, inputs_k, inputs_q
def get_decomposed_qa_outputs(FLAGS, features, is_training):
    """Build the decomposed XLNet QA graph and return its output tensors.

    The question and the context are passed separately to
    `transformer_xl_decomposed`, together with one attention mask for each
    of them and one for their concatenation (context first, then question).
    Span and answer-class heads are built on the last of the returned
    outputs.

    Args:
        FLAGS: config object; reads `max_seq_length`, `max_first_length`,
            `model_config_path`, `sep_layer`, `dropout` and `num_classes`.
        features: dict holding "question_ids" and "context_ids", batch on
            the first axis.
        is_training: bool; also controls whether the log-probabilities are
            returned.

    Returns:
        dict with 'upper_outputs', "start_logits", "end_logits" and
        "cls_logits"; when `is_training` is true also "start_log_probs",
        "end_log_probs" and "cls_log_probs".
    """
    question_ids = features["question_ids"]
    context_ids = features["context_ids"]
    seq_len = FLAGS.max_seq_length
    # NOTE(review): the +2 presumably reserves room for special tokens
    # around the question -- confirm against the feature converter.
    q_seq_len = FLAGS.max_first_length + 2
    ctx_seq_len = seq_len - q_seq_len
    # Number of non-zero question ids per example, offset by the context
    # length because the question follows the context in the joint sequence.
    # NOTE(review): presumably this is the CLS position -- confirm there is
    # no off-by-one against the input layout.
    q_mask_int = tf.cast(tf.cast(question_ids, tf.bool), tf.int32)
    cls_index = tf.reshape(
        tf.reduce_sum(q_mask_int, axis=1) + ctx_seq_len, [-1])
    # 0 for mask out (span masking over the context is currently disabled)
    # q_zeros = tf.zeros_like(question_ids)
    # p_ids = tf.concat([context_ids, q_zeros], axis=1)
    # p_mask = tf.cast(tf.cast(p_ids, tf.bool), tf.float32)

    # Switch both inputs from batch-major to sequence-major.
    question_ids = tf.transpose(question_ids, [1, 0])
    context_ids = tf.transpose(context_ids, [1, 0])

    q_attn_mask = get_attention_mask(question_ids, q_seq_len)
    c_attn_mask = get_attention_mask(context_ids, ctx_seq_len)
    qc_attn_mask = get_attention_mask(
        tf.concat([context_ids, question_ids], axis=0), seq_len)

    xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
    run_config = xlnet.create_run_config(is_training, True, FLAGS)
    initializer = xlnet._get_initializer(run_config)
    tfm_args = dict(
        n_token=xlnet_config.n_token,
        initializer=initializer,
        attn_type="bi",
        n_layer=xlnet_config.n_layer,
        d_model=xlnet_config.d_model,
        n_head=xlnet_config.n_head,
        d_head=xlnet_config.d_head,
        d_inner=xlnet_config.d_inner,
        ff_activation=xlnet_config.ff_activation,
        untie_r=xlnet_config.untie_r,

        is_training=run_config.is_training,
        use_bfloat16=run_config.use_bfloat16,
        use_tpu=run_config.use_tpu,
        dropout=run_config.dropout,
        dropatt=run_config.dropatt,

        # mem_len=run_config.mem_len,
        # reuse_len=run_config.reuse_len,
        # bi_data=run_config.bi_data,
        clamp_len=run_config.clamp_len,
        # same_length=run_config.same_length,
        ctx_ids=context_ids,
        q_ids=question_ids,
        q_seq_len=q_seq_len,
        ctx_seq_len=ctx_seq_len,
        sep_layer=FLAGS.sep_layer,
        q_attn_mask=q_attn_mask,
        c_attn_mask=c_attn_mask,
        qc_attn_mask=qc_attn_mask,
    )

    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        upper_outputs = transformer_xl_decomposed(**tfm_args)

    # NOTE(review): the head scopes below are placed outside the "model"
    # scope; the nesting was reconstructed from a whitespace-collapsed
    # source and decides the variable names -- confirm against checkpoints.
    # The heads read the last entry of the returned outputs.
    output = upper_outputs[-1]
    return_dict = {'upper_outputs': upper_outputs}
    with tf.variable_scope("logits"):
        # logits: seq, batch_size, 2
        logits = tf.layers.dense(output, 2, kernel_initializer=initializer)

        # logits: 2, batch_size, seq
        logits = tf.transpose(logits, [2, 1, 0])

        # start_logits: batch_size, seq
        # end_logits: batch_size, seq
        start_logits, end_logits = tf.unstack(logits, axis=0)

        # The softmax runs over every position; masking by p_mask is
        # disabled:
        # start_logits_masked = start_logits * p_mask - 1e30 * (1 - p_mask)
        # start_log_probs = tf.nn.log_softmax(start_logits_masked, -1)
        start_log_probs = tf.nn.log_softmax(start_logits, -1)

        # end_logits_masked = end_logits * p_mask - 1e30 * (1 - p_mask)
        # end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
        end_log_probs = tf.nn.log_softmax(end_logits, -1)

    return_dict["start_logits"] = start_logits
    return_dict["end_logits"] = end_logits
    if is_training:
        return_dict["start_log_probs"] = start_log_probs
        return_dict["end_log_probs"] = end_log_probs

    # an additional layer to predict answer class, 0: span, 1:yes, 2:no
    with tf.variable_scope("answer_class"):
        # get the representation of CLS: the one-hot over positions selects
        # a single row of `output` per example.
        cls_index = tf.one_hot(cls_index, seq_len, axis=-1, dtype=tf.float32)
        cls_feature = tf.einsum("lbh,bl->bh", output, cls_index)

        ans_feature = tf.layers.dense(cls_feature, xlnet_config.d_model,
                                      activation=tf.tanh,
                                      kernel_initializer=initializer,
                                      name='pooler')

        ans_feature = tf.layers.dropout(ans_feature, FLAGS.dropout,
                                        training=is_training)
        # hotpot has 3 classes,
        # squad 2.0 has 2 classes
        cls_logits = tf.layers.dense(ans_feature, FLAGS.num_classes,
                                     kernel_initializer=initializer,
                                     name="cls")
        cls_log_probs = tf.nn.log_softmax(cls_logits, -1)

    return_dict["cls_logits"] = cls_logits
    if is_training:
        return_dict["cls_log_probs"] = cls_log_probs

    return return_dict
def transformer_xl_decomposed(n_token, n_layer, d_model, n_head,
                              d_head, d_inner, dropout, dropatt,
                              attn_type, is_training, initializer,
                              q_ids, ctx_ids, clamp_len=-1,
                              untie_r=False, use_tpu=True,
                              ff_activation='relu', use_bfloat16=False,
                              sep_layer=9, q_attn_mask=None,
                              c_attn_mask=None, qc_attn_mask=None,
                              q_seq_len=None, ctx_seq_len=None,
                              scope='transformer', **kwargs):
    """Encode question and context separately, then jointly.

    Layers `0 .. sep_layer - 1` process the context and the question as two
    independent sequences inside the same `layer_{i}` variable scope. The two
    hidden states are then concatenated along axis 0 (context first) and
    layers `sep_layer .. n_layer - 1` process the joint sequence.

    Args:
        n_token: vocabulary size forwarded to `embedding_lookup`.
        n_layer: total number of layers.
        d_model: hidden size; also used as the embedding size.
        n_head: number of attention heads.
        d_head: size of each attention head.
        d_inner: inner size forwarded to `positionwise_ffn`.
        dropout: dropout rate for embeddings, positional encodings and the
            final output; also forwarded to the attention and FFN helpers.
        dropatt: attention dropout rate forwarded to `rel_multihead_attn`.
        attn_type: forwarded to `relative_positional_encoding`, which accepts
            'bi' or 'uni'.
        is_training: whether dropout is active.
        initializer: initializer for all variables created here.
        q_ids: question token ids; the batch size is read from axis 1.
        ctx_ids: context token ids, assumed to share the layout of `q_ids`.
        clamp_len: forwarded to `relative_positional_encoding`.
        untie_r: if true, the relative-attention biases carry a leading
            `n_layer` axis and are indexed per layer; otherwise one set of
            biases is shared by all layers.
        use_tpu: forwarded to `embedding_lookup`.
        ff_activation: activation name forwarded to `positionwise_ffn`.
        use_bfloat16: use `tf.bfloat16` instead of `tf.float32`.
        sep_layer: index of the first joint layer; must be smaller than
            `n_layer`.
        q_attn_mask: attention mask for the question-only layers.
        c_attn_mask: attention mask for the context-only layers.
        qc_attn_mask: attention mask for the joint layers.
        q_seq_len: question length. Required in practice: the `None`
            default fails in `relative_positional_encoding`.
        ctx_seq_len: context length. Required in practice, same as
            `q_seq_len`.
        scope: name of the enclosing variable scope.
        **kwargs: accepted and ignored.

    Returns:
        A list with one tensor per joint layer, in layer order. The last
        entry is the final layer output after dropout.

    Raises:
        ValueError: if `sep_layer >= n_layer`, i.e. no joint layer would be
            built and there would be no output to return.
    """
    # Fail early with a clear message; otherwise the function only breaks at
    # the very end with an IndexError on the empty list of joint outputs.
    if sep_layer >= n_layer:
        raise ValueError(
            '`sep_layer` ({}) must be smaller than `n_layer` ({}): at least '
            'one joint layer is required.'.format(sep_layer, n_layer))

    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32
    logger.info('Use float type {}'.format(tf_float))

    with tf.variable_scope(scope):
        if untie_r:
            r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
        else:
            r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)

        batch_size = tf.shape(q_ids)[1]

        # #### Word embedding (question and context share one table)
        q_emb, _ = embedding_lookup(x=q_ids,
                                    n_token=n_token,
                                    d_embed=d_model,
                                    initializer=initializer,
                                    use_tpu=use_tpu,
                                    dtype=tf_float,
                                    scope='word_embedding')
        c_emb, _ = embedding_lookup(x=ctx_ids,
                                    n_token=n_token,
                                    d_embed=d_model,
                                    initializer=initializer,
                                    use_tpu=use_tpu,
                                    dtype=tf_float,
                                    reuse=True,
                                    scope='word_embedding')
        q_output_h = tf.layers.dropout(q_emb, dropout, training=is_training)
        ctx_output_h = tf.layers.dropout(c_emb, dropout,
                                         training=is_training)

        # #### Segment embedding
        if untie_r:
            r_s_bias = tf.get_variable('r_s_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
        else:
            # default case (tie)
            r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)

        seg_embed = tf.get_variable('seg_embed',
                                    [n_layer, 2, n_head, d_head],
                                    dtype=tf_float, initializer=initializer)

        # Convert segment ids to one-hot segment matrices.
        # `1` indicates not in the same segment [qlen x klen x bsz].
        # Context tokens are segment 0, question tokens are segment 1.
        ctx_seg_ids = tf.zeros_like(ctx_ids, dtype=tf.int32)
        ctx_seg_mat = tf.cast(
            tf.logical_not(
                tf.equal(ctx_seg_ids[:, None], ctx_seg_ids[None, :])),
            tf.int32)
        ctx_seg_mat = tf.one_hot(ctx_seg_mat, 2, dtype=tf_float)

        q_seg_ids = tf.ones_like(q_ids, dtype=tf.int32)
        q_seg_mat = tf.cast(
            tf.logical_not(tf.equal(q_seg_ids[:, None], q_seg_ids[None, :])),
            tf.int32)
        q_seg_mat = tf.one_hot(q_seg_mat, 2, dtype=tf_float)

        seg_ids = tf.concat([ctx_seg_ids, q_seg_ids], axis=0)
        seg_mat = tf.cast(
            tf.logical_not(tf.equal(seg_ids[:, None], seg_ids[None, :])),
            tf.int32)
        seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)

        # #### Positional encoding
        # FIXME: better use of relative pos emb
        q_pos_emb = relative_positional_encoding(
            q_seq_len, q_seq_len, d_model, clamp_len, attn_type,
            bsz=batch_size, dtype=tf_float)
        q_pos_emb = tf.layers.dropout(q_pos_emb, dropout,
                                      training=is_training)

        ctx_pos_emb = relative_positional_encoding(
            ctx_seq_len, ctx_seq_len, d_model, clamp_len, attn_type,
            bsz=batch_size, dtype=tf_float)
        ctx_pos_emb = tf.layers.dropout(ctx_pos_emb, dropout,
                                        training=is_training)

        seq_len = ctx_seq_len + q_seq_len
        pos_emb = relative_positional_encoding(
            seq_len, seq_len, d_model, clamp_len, attn_type,
            bsz=batch_size, dtype=tf_float)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

        def _layer_params(layer_idx):
            """Select the biases and segment embedding of one layer."""
            s_bias = r_s_bias[layer_idx] if untie_r else r_s_bias
            w_bias = r_w_bias[layer_idx] if untie_r else r_w_bias
            r_bias = r_r_bias[layer_idx] if untie_r else r_r_bias
            return dict(r_w_bias=w_bias, r_r_bias=r_bias, r_s_bias=s_bias,
                        seg_embed=seg_embed[layer_idx])

        def _attn_ffn(hidden, rel_pos, segment_mat, mask, params, reuse):
            """Run relative attention followed by the position-wise FFN."""
            hidden = rel_multihead_attn(
                h=hidden, r=rel_pos, seg_mat=segment_mat, attn_mask=mask,
                mems=None, d_model=d_model, n_head=n_head, d_head=d_head,
                dropout=dropout, dropatt=dropatt, is_training=is_training,
                kernel_initializer=initializer, reuse=reuse, **params)
            return positionwise_ffn(
                inp=hidden, d_model=d_model, d_inner=d_inner,
                dropout=dropout, kernel_initializer=initializer,
                activation_type=ff_activation, is_training=is_training,
                reuse=reuse)

        # #### Separate layers: context and question never see each other.
        for i in range(sep_layer):
            params = _layer_params(i)
            with tf.variable_scope('layer_{}'.format(i)):
                ctx_output_h = _attn_ffn(ctx_output_h, ctx_pos_emb,
                                         ctx_seg_mat, c_attn_mask, params,
                                         reuse=False)
                # Built second in the same scope with AUTO_REUSE, presumably
                # so the question branch picks up the variables the context
                # branch just created (depends on how the helpers use
                # `reuse`).
                q_output_h = _attn_ffn(q_output_h, q_pos_emb, q_seg_mat,
                                       q_attn_mask, params,
                                       reuse=tf.AUTO_REUSE)

        # #### Joint layers over the concatenated [context; question].
        output_h = tf.concat([ctx_output_h, q_output_h], axis=0)
        upper_outputs = []
        for i in range(sep_layer, n_layer):
            params = _layer_params(i)
            with tf.variable_scope('layer_{}'.format(i)):
                output_h = _attn_ffn(output_h, pos_emb, seg_mat,
                                     qc_attn_mask, params, reuse=False)
                upper_outputs.append(output_h)

        # Only the final layer output gets the closing dropout.
        output = tf.layers.dropout(output_h, dropout, training=is_training)
        upper_outputs[-1] = output
        return upper_outputs
def get_qa_outputs(FLAGS, features, is_training):
    """Build span-extraction QA outputs (e.g. SQuAD) on top of XLNet.

    Args:
        FLAGS: configuration object. This function reads
            `model_config_path`, `dropout` and `num_classes`, and forwards
            the object to `xlnet.create_run_config`.
        features: mapping holding the "input_ids" and "segment_ids" tensors.
            They are handled as [batch_size, seq] here (reduced over axis 1,
            then transposed).
        is_training: whether the graph is built for training.

    Returns:
        A dict that always holds "cls_logits". When `is_training` is true it
        also holds "start_log_probs", "end_log_probs" and "cls_log_probs";
        otherwise it holds the unmasked "start_logits" and "end_logits".
    """
    token_ids = features["input_ids"]
    segment_ids = features["segment_ids"]

    # 1 where the token id is non-zero.
    nonzero_int = tf.cast(tf.cast(token_ids, tf.bool), tf.int32)
    # Per-example count of non-zero ids, used below to pick the CLS feature.
    cls_position = tf.reshape(tf.reduce_sum(nonzero_int, axis=1), [-1])
    # 1 where the segment id is non-zero; those positions are pushed to
    # -1e30 before the span softmax.
    p_mask = tf.cast(tf.cast(segment_ids, tf.bool), tf.float32)

    # The model is fed [seq, batch_size] tensors; its input mask is 1 where
    # the token id is zero.
    token_ids = tf.transpose(token_ids, [1, 0])
    zero_id_mask = 1 - tf.cast(nonzero_int, tf.float32)
    zero_id_mask = tf.transpose(zero_id_mask, [1, 0])
    segment_ids = tf.transpose(segment_ids, [1, 0])
    seq_len = tf.shape(token_ids)[0]

    xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
    run_config = xlnet.create_run_config(is_training, True, FLAGS)

    model = xlnet.XLNetModel(
        xlnet_config=xlnet_config,
        run_config=run_config,
        input_ids=token_ids,
        seg_ids=segment_ids,
        input_mask=zero_id_mask)
    sequence_output = model.get_sequence_output()
    initializer = model.get_initializer()

    outputs = {}

    with tf.variable_scope("logits"):
        # dense output is [seq, batch_size, 2]; move the start/end axis to
        # the front so it can be unstacked: [2, batch_size, seq].
        span_logits = tf.layers.dense(
            sequence_output, 2, kernel_initializer=initializer)
        span_logits = tf.transpose(span_logits, [2, 1, 0])

        # Each of the two is [batch_size, seq].
        start_logits, end_logits = tf.unstack(span_logits, axis=0)

        keep = 1 - p_mask
        masked_start = start_logits * keep - 1e30 * p_mask
        start_log_probs = tf.nn.log_softmax(masked_start, -1)

        keep = 1 - p_mask
        masked_end = end_logits * keep - 1e30 * p_mask
        end_log_probs = tf.nn.log_softmax(masked_end, -1)

    if is_training:
        outputs["start_log_probs"] = start_log_probs
        outputs["end_log_probs"] = end_log_probs
    else:
        outputs["start_logits"] = start_logits
        outputs["end_logits"] = end_logits

    # Extra head predicting the answer class, 0: span, 1: yes, 2: no.
    with tf.variable_scope("answer_class"):
        # One-hot over the sequence axis selects the CLS representation.
        cls_selector = tf.one_hot(
            cls_position, seq_len, axis=-1, dtype=tf.float32)
        cls_feature = tf.einsum("lbh,bl->bh", sequence_output, cls_selector)

        pooled = tf.layers.dense(
            cls_feature,
            xlnet_config.d_model,
            activation=tf.tanh,
            kernel_initializer=initializer,
            name="pooler")
        pooled = tf.layers.dropout(
            pooled, FLAGS.dropout, training=is_training)

        # hotpot has 3 classes, squad 2.0 has 2 classes.
        cls_logits = tf.layers.dense(
            pooled,
            FLAGS.num_classes,
            kernel_initializer=initializer,
            name="cls")
        cls_log_probs = tf.nn.log_softmax(cls_logits, -1)

    if is_training:
        outputs["cls_log_probs"] = cls_log_probs
    outputs["cls_logits"] = cls_logits

    return outputs