def metric_fn(per_example_loss, label_ids, logits, is_real_example):
  """Compute Matthew's correlation for CoLA."""
  predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
  # https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
  # Note: tf.metrics.* take `labels` first; pass keyword arguments so the
  # labels and predictions are not silently swapped.
  tp, tp_op = tf.metrics.true_positives(
      labels=label_ids, predictions=predictions, weights=is_real_example)
  tn, tn_op = tf.metrics.true_negatives(
      labels=label_ids, predictions=predictions, weights=is_real_example)
  fp, fp_op = tf.metrics.false_positives(
      labels=label_ids, predictions=predictions, weights=is_real_example)
  fn, fn_op = tf.metrics.false_negatives(
      labels=label_ids, predictions=predictions, weights=is_real_example)

  # Compute Matthew's correlation.
  mcc = tf.div_no_nan(
      tp * tn - fp * fn,
      tf.pow((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn), 0.5))

  # Compute accuracy.
  accuracy = tf.metrics.accuracy(
      labels=label_ids, predictions=predictions, weights=is_real_example)

  loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)

  return {
      "matthew_corr": (mcc, tf.group(tp_op, tn_op, fp_op, fn_op)),
      "eval_accuracy": accuracy,
      "eval_loss": loss,
  }
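
# A minimal, hypothetical sanity check of the MCC formula used above, written
# in plain NumPy so it runs without a TF1 graph. The toy labels/predictions
# are invented for illustration; tf.div_no_nan's semantics (return 0 when the
# denominator is 0) are reproduced by hand.
import numpy as np

def mcc_from_counts(tp, tn, fp, fn):
  # MCC = (tp*tn - fp*fn) / sqrt((tp+fp)(tp+fn)(tn+fp)(tn+fn)), 0 if undefined.
  denom = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
  return (tp * tn - fp * fn) / denom if denom > 0 else 0.0

labels = np.array([1, 1, 0, 0, 1, 0])
preds = np.array([1, 0, 0, 1, 1, 0])
tp = np.sum((preds == 1) & (labels == 1))
tn = np.sum((preds == 0) & (labels == 0))
fp = np.sum((preds == 1) & (labels == 0))
fn = np.sum((preds == 0) & (labels == 1))
print(mcc_from_counts(tp, tn, fp, fn))  # ~0.333 for this toy batch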
def f1_score(answer_ids, prediction_ids, vocab):
  """Compute F1 score between answer tokens and prediction tokens.

  Args:
    answer_ids: <int32> [batch_size, answer_length]
    prediction_ids: <int32> [batch_size, prediction_length]
    vocab: Instance of text_utils.Vocab.

  Returns:
    score: <float32> [batch_size] tensor of [0.0, 1.0].
  """
  # Order insensitive, so we just create a vocabulary sized bit tensor where
  # the vocabulary items that are not to be counted are masked out.
  vocab_size = len(vocab)
  remove_ids = list(_get_normalized_set(vocab))
  remove_mask = tf.expand_dims(tf.one_hot(remove_ids, vocab_size), 0)
  remove_mask = tf.reduce_sum(remove_mask, axis=1)
  remove_mask = tf.cast(tf.equal(remove_mask, 0), tf.float32)

  # [batch_size, vocab_size]
  answer_ids = tf.reduce_sum(tf.one_hot(answer_ids, vocab_size), axis=1)
  answer_ids *= remove_mask

  # [batch_size, vocab_size]
  prediction_ids = tf.reduce_sum(
      tf.one_hot(prediction_ids, vocab_size), axis=1)
  prediction_ids *= remove_mask

  # Compute multiset intersection, and count the size.
  intersection = tf.minimum(prediction_ids, answer_ids)
  intersection = tf.reduce_sum(intersection, axis=1)

  # Compute F1 score:
  #   Re(A, B) = |A \cap B| / |B|
  #   Pr(A, B) = |A \cap B| / |A|
  #   F1(A, B) = 2 * (Pr * Re) / (Pr + Re)
  recall = tf.div_no_nan(intersection, tf.reduce_sum(answer_ids, axis=1))
  precision = tf.div_no_nan(intersection,
                            tf.reduce_sum(prediction_ids, axis=1))
  score = 2 * tf.div_no_nan(precision * recall, precision + recall)
  return score
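
# A hedged NumPy sketch of the same bag-of-tokens F1 computed above, useful
# for checking the multiset-intersection logic outside a TF session. The toy
# ids below are made up, and `f1_score` additionally masks out normalized
# tokens via `_get_normalized_set`, which this sketch omits.
import numpy as np

def bag_f1(answer, prediction, vocab_size):
  a = np.bincount(answer, minlength=vocab_size).astype(float)
  p = np.bincount(prediction, minlength=vocab_size).astype(float)
  inter = np.minimum(a, p).sum()
  recall = inter / a.sum() if a.sum() else 0.0
  precision = inter / p.sum() if p.sum() else 0.0
  denom = precision + recall
  return 2 * precision * recall / denom if denom else 0.0

# Tokens {3, 5} overlap out of 3 answer and 3 prediction tokens -> F1 = 2/3.
print(bag_f1(np.array([3, 5, 7]), np.array([3, 5, 9]), vocab_size=16))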
def _build_model(self):
  self.graph_built = True
  tf.set_random_seed(self.seed)
  self.user_indices = tf.placeholder(tf.int32, shape=[None])
  self.item_indices = tf.placeholder(tf.int32, shape=[None])
  self.user_interacted_seq = tf.placeholder(
      tf.int32, shape=[None, self.interaction_num])
  self.user_interacted_len = tf.placeholder(tf.float32, shape=[None])
  self.labels = tf.placeholder(tf.float32, shape=[None])
  self.is_training = tf.placeholder_with_default(False, shape=[])
  self.concat_embed = []

  user_features = tf.get_variable(
      name="user_features",
      shape=[self.n_users + 1, self.embed_size],
      initializer=tf_truncated_normal(0.0, 0.01),
      regularizer=self.reg)
  item_features = tf.get_variable(
      name="item_features",
      shape=[self.n_items + 1, self.embed_size],
      initializer=tf_truncated_normal(0.0, 0.01),
      regularizer=self.reg)
  user_embed = tf.nn.embedding_lookup(user_features, self.user_indices)
  item_embed = tf.nn.embedding_lookup(item_features, self.item_indices)

  # Unknown items are padded to the 0-vector.
  zero_padding_op = tf.scatter_update(
      item_features, self.n_items,
      tf.zeros([self.embed_size], dtype=tf.float32))
  with tf.control_dependencies([zero_padding_op]):
    # B * seq * K
    multi_item_embed = tf.nn.embedding_lookup(
        item_features, self.user_interacted_seq)
  pooled_embed = tf.div_no_nan(
      tf.reduce_sum(multi_item_embed, axis=1),
      tf.expand_dims(tf.sqrt(self.user_interacted_len), axis=1))
  self.concat_embed.extend([user_embed, item_embed, pooled_embed])

  if self.sparse:
    self._build_sparse()
  if self.dense:
    self._build_dense()

  concat_embed = tf.concat(self.concat_embed, axis=1)
  mlp_layer = dense_nn(concat_embed,
                       self.hidden_units,
                       use_bn=self.use_bn,
                       dropout_rate=self.dropout_rate,
                       is_training=self.is_training)
  self.output = tf.reshape(
      tf.layers.dense(inputs=mlp_layer, units=1), [-1])
  count_params()
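
# Illustrative-only NumPy check of the pooling above: interacted-item
# embeddings are summed over the sequence axis and scaled by 1/sqrt(seq_len),
# with a zero result when seq_len == 0 (mirroring tf.div_no_nan). Shapes and
# values here are invented.
import numpy as np

multi_item_embed = np.ones((2, 4, 3))   # B=2, seq=4, K=3
seq_len = np.array([4.0, 0.0])          # second user has no history
summed = multi_item_embed.sum(axis=1)   # [B, K]
denom = np.sqrt(seq_len)[:, None]       # [B, 1]
pooled = np.divide(summed, denom, out=np.zeros_like(summed),
                   where=denom != 0)
print(pooled)  # first row = 4/sqrt(4) = 2.0, second row = 0.0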
def _attention_unit(self, queries, keys, keys_len):
  if self.use_tf_attention:
    query_masks = tf.cast(
        tf.ones_like(tf.reshape(self.user_interacted_len, [-1, 1])),
        dtype=tf.bool)
    key_masks = tf.sequence_mask(self.user_interacted_len, self.max_seq_len)
    queries = tf.expand_dims(queries, axis=1)
    attention = tf.keras.layers.Attention(use_scale=False)
    pooled_outputs = attention(inputs=[queries, keys],
                               mask=[query_masks, key_masks])
    return pooled_outputs
  else:
    # queries: B * K, keys: B * seq * K
    queries = tf.expand_dims(queries, axis=1)
    # B * seq * K
    queries = tf.tile(queries, [1, self.max_seq_len, 1])
    queries_keys_cross = tf.concat(
        [queries, keys, queries - keys, queries * keys], axis=2)
    mlp_layer = dense_nn(queries_keys_cross, (16,), use_bn=False,
                         activation=tf.nn.sigmoid, name="attention")
    # B * seq * 1
    mlp_layer = tf.layers.dense(mlp_layer, units=1, activation=None)
    # attention_weights = tf.transpose(mlp_layer, [0, 2, 1])
    attention_weights = tf.layers.flatten(mlp_layer)

    key_masks = tf.sequence_mask(keys_len, self.max_seq_len)
    paddings = tf.ones_like(attention_weights) * (-2 ** 32 + 1)
    attention_scores = tf.where(key_masks, attention_weights, paddings)
    attention_scores = tf.div_no_nan(
        attention_scores,
        tf.sqrt(tf.cast(keys.get_shape().as_list()[-1], tf.float32)))
    # B * 1 * seq
    attention_scores = tf.expand_dims(tf.nn.softmax(attention_scores), 1)
    # B * 1 * K
    pooled_outputs = attention_scores @ keys
    return pooled_outputs
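
# A small NumPy sketch (assumed shapes, toy numbers) of the masked-softmax
# pooling in the else-branch above: padded positions get a very large
# negative score, so they receive ~0 attention weight after the softmax.
import numpy as np

scores = np.array([[2.0, 1.0, 0.5]])            # B=1, seq=3 raw scores
mask = np.array([[True, True, False]])          # last position is padding
masked = np.where(mask, scores, -2.0 ** 32 + 1)
weights = np.exp(masked - masked.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)   # softmax over seq
keys = np.arange(9, dtype=float).reshape(1, 3, 3)  # B * seq * K
pooled = weights[:, None, :] @ keys             # B * 1 * K
print(weights)  # third weight is ~0
print(pooled)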
def spherical_normalization(x, rectify=True):
  """Apply area weights and normalization to spherical distributions.

  The sum of all pixel values over the spherical input will be one.

  Args:
    x: [BATCH, HEIGHT, WIDTH, CHANNELS] spherical raw distributions.
    rectify: apply softplus to the input x if true.

  Returns:
    [BATCH, HEIGHT, WIDTH, CHANNELS] normalized distributions.
  """
  with tf.name_scope(None, 'spherical_normalization', [x]):
    # Apply softplus to make the input non-negative.
    shape = x.shape.as_list()
    height = shape[1]
    if rectify:
      x = tf.nn.softplus(x)
    weighted = x * equirectangular_area_weights(height)
    # Return shape [BATCH, HEIGHT, WIDTH, CHANNELS].
    return tf.div_no_nan(
        x, tf.reduce_sum(weighted, axis=[1, 2], keepdims=True))
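
# `equirectangular_area_weights` is not shown here; a common choice (an
# assumption, not taken from this codebase) weights each row by
# sin(colatitude) at pixel centers, since equirectangular rows near the poles
# cover less area on the sphere. A minimal NumPy sketch under that
# assumption, showing that it is the *weighted* sum of the normalized output
# that comes to one:
import numpy as np

def area_weights(height):
  # Colatitude of each row center, in (0, pi).
  colat = (np.arange(height) + 0.5) * np.pi / height
  return np.sin(colat).reshape(1, height, 1, 1)

x = np.random.rand(1, 8, 16, 1)                 # softplus already applied
w = area_weights(8)
norm = x / (x * w).sum(axis=(1, 2), keepdims=True)
print((norm * w).sum(axis=(1, 2)))              # ~1.0 per batch/channel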
def model_fn(features, labels, mode, params):
  """Model function."""
  del labels

  # ==============================
  # Input features
  # ==============================
  # [batch_size, query_seq_len]
  query_inputs = features["query_inputs"]

  # [batch_size, num_candidates, candidate_seq_len]
  candidate_inputs = features["candidate_inputs"]

  # [batch_size, num_candidates, query_seq_len + candidate_seq_len]
  joint_inputs = features["joint_inputs"]

  # [batch_size, num_masks]
  mlm_targets = features["mlm_targets"]
  mlm_positions = features["mlm_positions"]
  mlm_mask = features["mlm_mask"]

  # ==============================
  # Create modules.
  # ==============================
  bert_module = hub.Module(
      spec=params["bert_hub_module_handle"],
      name="locbert",
      tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
      trainable=True)
  hub.register_module_for_export(bert_module, "locbert")

  embedder_module = hub.Module(
      spec=params["embedder_hub_module_handle"],
      name="embedder",
      tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
      trainable=True)
  hub.register_module_for_export(embedder_module, "embedder")

  if params["share_embedders"]:
    query_embedder_module = embedder_module
  else:
    query_embedder_module = hub.Module(
        spec=params["embedder_hub_module_handle"],
        name="embedder",
        tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
        trainable=True)
    # Export the separately-created query embedder (not embedder_module).
    hub.register_module_for_export(query_embedder_module, "query_embedder")

  # ==============================
  # Retrieve.
  # ==============================
  # [batch_size, projected_size]
  query_emb = query_embedder_module(
      inputs=dict(
          input_ids=query_inputs.token_ids,
          input_mask=query_inputs.mask,
          segment_ids=query_inputs.segment_ids),
      signature="projected")

  # [batch_size * num_candidates, candidate_seq_len]
  flat_candidate_inputs, unflatten = flatten_bert_inputs(candidate_inputs)

  # [batch_size * num_candidates, projected_size]
  flat_candidate_emb = embedder_module(
      inputs=dict(
          input_ids=flat_candidate_inputs.token_ids,
          input_mask=flat_candidate_inputs.mask,
          segment_ids=flat_candidate_inputs.segment_ids),
      signature="projected")

  # [batch_size, num_candidates, projected_size]
  unflattened_candidate_emb = unflatten(flat_candidate_emb)

  # [batch_size, num_candidates]
  retrieval_score = tf.einsum("BD,BND->BN", query_emb,
                              unflattened_candidate_emb)

  # ==============================
  # Read.
  # ==============================
  # [batch_size * num_candidates, query_seq_len + candidate_seq_len]
  flat_joint_inputs, unflatten = flatten_bert_inputs(joint_inputs)

  # [batch_size * num_candidates, num_masks]
  flat_mlm_positions, _ = tensor_utils.flatten(
      tf.tile(
          tf.expand_dims(mlm_positions, 1), [1, params["num_candidates"], 1]))

  batch_size, num_masks = tensor_utils.shape(mlm_targets)

  # [batch_size * num_candidates, query_seq_len + candidates_seq_len]
  flat_joint_bert_outputs = bert_module(
      inputs=dict(
          input_ids=flat_joint_inputs.token_ids,
          input_mask=flat_joint_inputs.mask,
          segment_ids=flat_joint_inputs.segment_ids,
          mlm_positions=flat_mlm_positions),
      signature="mlm",
      as_dict=True)

  # [batch_size, num_candidates]
  candidate_score = retrieval_score

  # [batch_size, num_candidates]
  candidate_log_probs = tf.math.log_softmax(candidate_score)

  # ==============================
  # Compute marginal log-likelihood.
  # ==============================
  # [batch_size * num_candidates, num_masks]
  flat_mlm_logits = flat_joint_bert_outputs["mlm_logits"]

  # [batch_size, num_candidates, num_masks, vocab_size]
  mlm_logits = tf.reshape(
      flat_mlm_logits, [batch_size, params["num_candidates"], num_masks, -1])
  mlm_log_probs = tf.math.log_softmax(mlm_logits)

  # [batch_size, num_candidates, num_masks]
  tiled_mlm_targets = tf.tile(
      tf.expand_dims(mlm_targets, 1), [1, params["num_candidates"], 1])

  # [batch_size, num_candidates, num_masks, 1]
  tiled_mlm_targets = tf.expand_dims(tiled_mlm_targets, -1)

  # [batch_size, num_candidates, num_masks, 1]
  gold_log_probs = tf.batch_gather(mlm_log_probs, tiled_mlm_targets)

  # [batch_size, num_candidates, num_masks]
  gold_log_probs = tf.squeeze(gold_log_probs, -1)

  # [batch_size, num_candidates, num_masks]
  joint_gold_log_probs = (
      tf.expand_dims(candidate_log_probs, -1) + gold_log_probs)

  # [batch_size, num_masks]
  marginal_gold_log_probs = tf.reduce_logsumexp(joint_gold_log_probs, 1)

  # [batch_size, num_masks]
  float_mlm_mask = tf.cast(mlm_mask, tf.float32)

  # []
  loss = -tf.div_no_nan(
      tf.reduce_sum(marginal_gold_log_probs * float_mlm_mask),
      tf.reduce_sum(float_mlm_mask))

  # ==============================
  # Optimization
  # ==============================
  num_warmup_steps = min(10000, max(100, int(params["num_train_steps"] / 10)))
  train_op = optimization.create_optimizer(
      loss=loss,
      init_lr=params["learning_rate"],
      num_train_steps=params["num_train_steps"],
      num_warmup_steps=num_warmup_steps,
      use_tpu=params["use_tpu"])

  # ==============================
  # Evaluation
  # ==============================
  eval_metric_ops = None if params["use_tpu"] else dict()
  if mode != tf.estimator.ModeKeys.PREDICT:
    # [batch_size, num_masks]
    retrieval_utility = marginal_gold_log_probs - gold_log_probs[:, 0]
    retrieval_utility *= tf.cast(features["mlm_mask"], tf.float32)

    # []
    retrieval_utility = tf.div_no_nan(
        tf.reduce_sum(retrieval_utility), tf.reduce_sum(float_mlm_mask))
    add_mean_metric("retrieval_utility", retrieval_utility, eval_metric_ops)

    has_timestamp = tf.cast(
        tf.greater(features["export_timestamp"], 0), tf.float64)
    off_policy_delay_secs = (
        tf.timestamp() - tf.cast(features["export_timestamp"], tf.float64))
    off_policy_delay_mins = off_policy_delay_secs / 60.0
    off_policy_delay_mins *= tf.cast(has_timestamp, tf.float64)
    add_mean_metric("off_policy_delay_mins", off_policy_delay_mins,
                    eval_metric_ops)

  # Create empty predictions to avoid errors when running in prediction mode.
  predictions = dict()

  if params["use_tpu"]:
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions)
  else:
    if eval_metric_ops is not None:
      # Make sure the eval metrics are updated during training so that we get
      # quick feedback from tensorboard summaries when debugging locally.
      with tf.control_dependencies([u for _, u in eval_metric_ops.values()]):
        loss = tf.identity(loss)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        predictions=predictions)
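
# A hedged NumPy illustration of the marginalization above: per-candidate
# log p(candidate) and log p(mask | candidate) are added, then log-sum-exp'd
# over the candidate axis to get log p(mask). The values are toy numbers.
import numpy as np
from scipy.special import logsumexp

candidate_log_probs = np.log([[0.7, 0.2, 0.1]])    # [batch, num_candidates]
gold_log_probs = np.log([[[0.9], [0.5], [0.1]]])   # [batch, cand, num_masks]
joint = candidate_log_probs[..., None] + gold_log_probs
marginal = logsumexp(joint, axis=1)                # [batch, num_masks]
print(np.exp(marginal))  # 0.7*0.9 + 0.2*0.5 + 0.1*0.1 = 0.74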
def metric_fn(per_example_loss, label_ids, logits, is_real_example):
  """Compute Matthew's correlation for CoLA."""
  predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
  # https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
  tp, tp_op = tf.metrics.true_positives(
      labels=label_ids, predictions=predictions, weights=is_real_example)
  tn, tn_op = tf.metrics.true_negatives(
      labels=label_ids, predictions=predictions, weights=is_real_example)
  fp, fp_op = tf.metrics.false_positives(
      labels=label_ids, predictions=predictions, weights=is_real_example)
  fn, fn_op = tf.metrics.false_negatives(
      labels=label_ids, predictions=predictions, weights=is_real_example)

  # Compute precision, recall and F1 score over the positive classes [1, 2].
  # Added for BioAlbert. Note: `num_labels` and `tf_metrics` are free names
  # here; they must come from the enclosing scope.
  precision = tf_metrics.precision(
      label_ids, predictions, num_labels, [1, 2], average="micro")
  recall = tf_metrics.recall(
      label_ids, predictions, num_labels, [1, 2], average="micro")
  f1 = tf_metrics.f1(
      label_ids, predictions, num_labels, [1, 2], average="micro")

  # Compute Matthew's correlation.
  mcc = tf.div_no_nan(
      tp * tn - fp * fn,
      tf.pow((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn), 0.5))

  # Compute accuracy.
  accuracy = tf.metrics.accuracy(
      labels=label_ids, predictions=predictions, weights=is_real_example)

  loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)

  return {
      "matthew_corr": (mcc, tf.group(tp_op, tn_op, fp_op, fn_op)),
      "accuracy": accuracy,
      "eval_loss": loss,
      "precision": precision,
      "recall": recall,
      "f1_score": f1,
  }
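
# Hedged NumPy sketch of micro-averaged precision/recall/F1 restricted to
# the positive classes [1, 2], mirroring what the external tf_metrics calls
# above are expected to compute (toy labels; tf_metrics itself is not shown
# here).
import numpy as np

labels = np.array([0, 1, 2, 2, 1, 0])
preds = np.array([0, 1, 2, 1, 0, 2])
pos = [1, 2]
tp = sum(np.sum((preds == c) & (labels == c)) for c in pos)
fp = sum(np.sum((preds == c) & (labels != c)) for c in pos)
fn = sum(np.sum((preds != c) & (labels == c)) for c in pos)
precision = tp / (tp + fp) if tp + fp else 0.0
recall = tp / (tp + fn) if tp + fn else 0.0
f1 = (2 * precision * recall / (precision + recall)
      if precision + recall else 0.0)
print(precision, recall, f1)  # 0.5 0.5 0.5 on this toy batch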
def model_fn(features, labels, mode, params, vocab):
  """Model function that satisfies the Estimator API.

  Args:
    features: Dictionary of model input tensors.
    labels: Unused.
    mode: A tf.estimator.ModeKeys value.
    params: Dictionary of model parameters.
    vocab: A utils.text_utils.Vocab instance.

  Returns:
    spec: A tf.estimator.TPUEstimatorSpec.
  """
  del labels

  # ----------------------------------------------------------------------------
  # INITIALIZATION.
  # ----------------------------------------------------------------------------
  # Update model config from the pre-trained checkpoint.
  model = transformer_utils.TransformerModel(
      config=transformer_utils.TransformerConfig.from_dict(params),
      is_training=(mode == tf_estimator.ModeKeys.TRAIN))

  # Initialize QA model.
  rc_model = hub.Module(params["rc_model"])

  # image_features: [batch_size, num_regions, feature_size]
  # image_positions: [batch_size, num_regions]
  # image_mask: [batch_size, num_regions]
  image_features = features["object_features"].features
  image_positions = features["object_features"].positions
  image_mask = features["object_features"].mask

  # Expand mask by 1 to account for the leading [IMG] token.
  # [batch_size, num_regions + 1]
  batch_size = tensor_utils.shape(image_mask, 0)
  input_mask = tf.pad(image_mask, [[0, 0], [1, 0]], constant_values=1)

  # Encode the image and store the cached transformer values.
  # [batch_size, num_regions + 1, num_layers, num_heads, head_size]
  _, input_cache = model.compute_image_transformer(
      input_ids=tf.fill([batch_size, 1], vocab.t2i(vocab.IMG)),
      input_image=image_features,
      input_image_mask=input_mask,
      input_positions=image_positions)

  # ----------------------------------------------------------------------------
  # TRAINING
  # ----------------------------------------------------------------------------
  if mode == tf_estimator.ModeKeys.TRAIN:
    # MIXER-style training objective consists of two parts:
    #   1) Policy gradient on rewarded rollouts.
    #   2) MLE regularization on references.
    # The full loss is L_total = L_pg + L_mle.

    # Step 1: Policy gradient.
    # Compute and score policy rollouts (multiple per image).
    rollouts = reward_utils.compute_rollouts(
        model=model,
        rc_model=rc_model,
        features=features,
        encoder_cache=input_cache,
        encoder_cache_mask=input_mask,
        vocab=vocab,
        params=params)

    # Using a self-critical baseline, R'(y) = R(y) - b where b = argmax p(y|x),
    # sample a single rollout with non-zero reward.
    rollout, reward = reward_utils.sample_from_rollouts(
        rollouts=rollouts,
        baseline=rollouts.rewards[params["reward"]][:, 0],
        reward_type=params["reward"])

    # Compute the probability of the rollout (back-propable).
    # [batch_size, decode_length, input_length + decode_length]
    rollout_attention_mask = transformer_utils.compute_attention_mask(
        token_mask=rollout.mask[:, :-1], input_mask=input_mask)

    # [batch_size, decode_length, vocab_size]
    rollout_emb, _ = model.compute_transformer(
        input_ids=rollout.token_ids[:, :-1],
        input_segment_id=rollout.segment_ids[:, :-1],
        input_positions=rollout.positions[:, :-1],
        attention_mask=rollout_attention_mask,
        input_cache=input_cache,
        reuse=tf.AUTO_REUSE)

    # [batch_size, decode_length, vocab_size]
    rollout_logits = model.compute_logits(rollout_emb, reuse=tf.AUTO_REUSE)

    # Compute the RL loss, -R(y) * log p(y|x).
    # Some elements in this batch are MLE only, mask those out from the loss.
    rollout_mask = tf.cast(rollout.mask[:, 1:], tf.float32)
    pg_mask = tf.equal(features["input_type"], datasets.DatasetTypes.VQA)
    rollout_mask *= tf.expand_dims(tf.cast(pg_mask, tf.float32), 1)
    rl_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=rollout.token_ids[:, 1:],
        logits=rollout_logits,
        weights=tf.expand_dims(reward, 1) * rollout_mask,
        reduction=tf.losses.Reduction.SUM)
    rl_loss = tf.math.divide_no_nan(rl_loss, tf.reduce_sum(rollout_mask))

    # Step 2: MLE on references.
    # [batch_size, decode_length, input_length + decode_length]
    reference_attention_mask = transformer_utils.compute_attention_mask(
        token_mask=features["token_inputs"].mask, input_mask=input_mask)

    # [batch_size, decode_length, hidden_size]
    target_emb, _ = model.compute_transformer(
        input_ids=features["token_inputs"].token_ids,
        input_segment_id=features["token_inputs"].segment_ids,
        input_positions=features["token_inputs"].positions,
        attention_mask=reference_attention_mask,
        input_cache=input_cache,
        reuse=tf.AUTO_REUSE)

    # [batch_size, decode_length, vocab_size]
    target_logits = model.compute_logits(target_emb, reuse=tf.AUTO_REUSE)

    # Compute the MLE objective (cross-entropy loss).
    weights = features["token_outputs"].mask
    ref_mask = tf.equal(features["input_type"],
                        datasets.DatasetTypes.REFERENCE)
    weights *= tf.expand_dims(tf.cast(ref_mask, tf.int32), 1)
    reference_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=features["token_outputs"].token_ids,
        logits=target_logits,
        weights=weights)

    # Add both losses together.
    loss = rl_loss + reference_loss

    # BERT-style optimization with linear warmup.
    train_op = optimization.create_optimizer(
        loss=loss,
        init_lr=params["learning_rate"],
        num_train_steps=params["num_train_steps"],
        num_warmup_steps=params["num_warmup_steps"],
        use_tpu=params.get("use_tpu"))

    # Book-keeping.
    summaries = tpu_summaries.TpuSummaries(params["model_dir"])
    summaries.scalar("loss", loss)

    # Check what percentage of examples have non-zero reward.
    total_vqa = tf.reduce_sum(tf.cast(pg_mask, tf.float32))
    nonzero = tf.cast(tf.not_equal(reward, 0), tf.float32)
    nonzero *= tf.cast(pg_mask, tf.float32)
    total_nonzero = tf.reduce_sum(nonzero)
    summaries.scalar("density", tf.div_no_nan(total_nonzero, total_vqa))

    # Total (non-normalized) reward.
    reward = rollouts.rewards[params["reward"]][:, 0]
    reward *= tf.cast(pg_mask, tf.float32)
    total_reward = tf.reduce_sum(reward)
    summaries.scalar("reward", tf.div_no_nan(total_reward, total_vqa))
    host_call = summaries.get_host_call()
  else:
    loss = None
    train_op = None
    host_call = None

  # ----------------------------------------------------------------------------
  # TESTING.
  # ----------------------------------------------------------------------------
  if mode == tf_estimator.ModeKeys.PREDICT:
    decode_output = transformer_utils.beam_search_decode(
        model=model,
        encoder_cache=input_cache,
        encoder_cache_mask=input_mask,
        start_id=vocab.t2i(vocab.CLS),
        stop_id=vocab.t2i(vocab.SEP),
        segment_id=0,
        num_steps=params["decode_length"],
        beam_size=params["beam_size"],
        alpha=params["beam_length_penalty"],
        reuse=tf.AUTO_REUSE)
    predictions = dict(
        image_id=features.get("image_id", -1),
        question_id=features.get("question_id", -1),
        token_ids=decode_output.token_ids[:, :, 1:])
  else:
    predictions = None

  # ----------------------------------------------------------------------------
  # WARM-START.
  # ----------------------------------------------------------------------------
  # Initialize from pretrained model.
  def scaffold_fn():
    """Init op run on host."""
    checkpoint = params["base_model"]
    if params["warm_start_path"]:
      checkpoint = params["warm_start_path"]
    if checkpoint:
      checkpoint_utils.init_from_checkpoint(checkpoint)
    return tf.train.Scaffold()

  return tf_estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      predictions=predictions,
      scaffold_fn=scaffold_fn,
      host_call=host_call,
  )
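
# Toy NumPy sketch (invented numbers) of the self-critical weighting used in
# the training branch above: each sampled rollout's reward is centered by the
# greedy (argmax) rollout's reward, so sequences that beat the baseline get
# positive weight in -R'(y) * log p(y|x) and the rest get negative or zero
# weight. This is only an illustration of the objective, not the sampling
# logic in reward_utils.
import numpy as np

rollout_rewards = np.array([0.8, 0.2, 0.5])  # sampled rollouts
baseline = 0.5                               # reward of the argmax rollout
centered = rollout_rewards - baseline        # R'(y) = R(y) - b
log_probs = np.log([0.3, 0.4, 0.3])          # log p(y|x) per rollout
pg_loss = -(centered * log_probs).sum()
print(centered, pg_loss)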
def _stochastic_prob(prob_mean, prob_std):
  # Map the mean probability to logit space: logit(p) = log(p / (1 - p)).
  prob_mean_mod = tf.math.log(tf.div_no_nan(prob_mean, 1 - prob_mean))
  # Sample a logit from a normal centered at the mean logit.
  # sample_random_normal = tf.random.normal([], prob_mean_mod, prob_std, seed=1)
  sample_random_normal = tf.random.normal([], prob_mean_mod, prob_std)
  # Map back to probability space with the logistic sigmoid.
  prob = tf.divide(1, 1 + tf.math.exp(-sample_random_normal))
  return prob
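
# Quick NumPy demo (assumed usage) of the same logit-normal trick: perturb a
# probability in logit space with Gaussian noise, then squash back through
# the sigmoid so every sample stays in (0, 1).
import numpy as np

rng = np.random.default_rng(0)
prob_mean, prob_std = 0.9, 0.5
logit = np.log(prob_mean / (1.0 - prob_mean))
samples = 1.0 / (1.0 + np.exp(-rng.normal(logit, prob_std, size=5)))
print(samples)  # all in (0, 1), concentrated near 0.9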