def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """Build the `TPUEstimatorSpec` for the classifier on top of `GroverModel`.

    The configuration names used below (`config`, `num_labels`,
    `pool_token_id`, `lm_loss_coef`, `learning_rate`, `num_train_steps`,
    `num_warmup_steps`, `init_checkpoint`, `use_tpu`) are not parameters of
    this function; they are resolved from the surrounding scope.

    Args:
      features: mapping that must provide "input_ids" and "label_ids". An
        optional "is_real_example" entry is cast to float32 and used as the
        weight of the evaluation metrics; when absent, all-ones weights with
        the shape of "label_ids" are used.
      labels: unused.
      mode: a `tf.estimator.ModeKeys` value.
      params: only `params['model_dir']` is read, and only when training
        with `use_tpu` set.

    Returns:
      A `tf.contrib.tpu.TPUEstimatorSpec`. TRAIN and EVAL specs carry the
      combined loss `lm_loss_coef * lm_loss + classification_loss`; any other
      mode yields predictions with the keys 'logits' and 'probs'.
    """
    tf.logging.info("*** Features ***")
    for feature_name in sorted(features):
        tf.logging.info(" name = %s, shape = %s",
                        feature_name, features[feature_name].shape)

    input_ids = features["input_ids"]
    label_ids = features["label_ids"]
    if "is_real_example" not in features:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
    else:
        is_real_example = tf.cast(features["is_real_example"],
                                  dtype=tf.float32)

    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # The language model is built first; its LM loss is mixed into the
    # classification objective further down.
    model = GroverModel(
        config=config,
        is_training=is_training,
        input_ids=input_ids,
        pad_token_id=config.pad_token_id,
        chop_off_last_token=False,
    )

    with tf.variable_scope('classification'):
        pooled = model.pooled_output(pool_token_id)
        if is_training:
            pooled = dropout(pooled, dropout_prob=0.1)

        logits = tf.layers.dense(
            pooled,
            num_labels,
            kernel_initializer=create_initializer(config.initializer_range),
            name='logits')
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        # Cross-entropy against the one-hot encoded labels.
        one_hot_labels = tf.one_hot(label_ids, depth=num_labels,
                                    dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        class_loss = tf.reduce_mean(per_example_loss)

    total_loss = lm_loss_coef * model.lm_loss() + class_loss

    if is_training:
        train_op, train_metrics = optimization_adafactor.create_optimizer(
            total_loss, learning_rate, num_train_steps, num_warmup_steps,
            use_tpu)
        train_metrics['minibatch_cls_loss'] = class_loss
        is_correct = tf.equal(
            tf.argmax(logits, axis=-1, output_type=tf.int32), label_ids)
        train_metrics['minibatch_acc'] = tf.reduce_mean(
            tf.cast(is_correct, tf.float32))
    else:
        train_op, train_metrics = None, {}

    # Only trainable variables are offered to the checkpoint assignment map.
    tvars = tf.trainable_variables()

    scaffold_fn = None
    initialized_variable_names = {}
    if init_checkpoint:
        assignment_map, initialized_variable_names = (
            get_assignment_map_from_checkpoint(tvars, init_checkpoint))
        if not use_tpu:
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        else:
            # With `use_tpu`, the restore is deferred into the scaffold
            # function that is handed to the estimator spec.
            def tpu_scaffold():
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold

    tf.logging.debug("**** Trainable Variables ****")
    for var in tvars:
        from_ckpt = var.name in initialized_variable_names
        init_string = ", *INIT_FROM_CKPT*" if from_ckpt else ""
        tf.logging.debug(" name = %s, shape = %s%s", var.name, var.shape,
                         init_string)

    if is_training:
        if not use_tpu:
            # Off-TPU: log a running mean of the loss every 100 iterations.
            loss_hook = tf.train.LoggingTensorHook(
                {'loss': tf.metrics.mean(total_loss)[1]}, every_n_iter=100)
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                training_hooks=[loss_hook],
                scaffold_fn=scaffold_fn)

        # On TPU the training scalars are reported through a host call.
        host_call = construct_scalar_host_call(
            metric_dict=train_metrics,
            model_dir=params['model_dir'],
            prefix='training/')
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            host_call=host_call,
            scaffold_fn=scaffold_fn)

    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(example_losses, true_labels, class_logits, weights):
            """Return accuracy and mean loss, both weighted by `weights`."""
            predicted = tf.argmax(class_logits, axis=-1, output_type=tf.int32)
            accuracy = tf.metrics.accuracy(
                labels=true_labels, predictions=predicted, weights=weights)
            loss = tf.metrics.mean(values=example_losses, weights=weights)
            return {
                "eval_accuracy": accuracy,
                "eval_loss": loss,
            }

        metric_inputs = [per_example_loss, label_ids, logits, is_real_example]
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metrics=(metric_fn, metric_inputs),
            scaffold_fn=scaffold_fn)

    # Any remaining mode: expose raw logits and their softmax.
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions={
            'logits': logits,
            'probs': tf.nn.softmax(logits, axis=-1),
        },
        scaffold_fn=scaffold_fn)
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """Build the `TPUEstimatorSpec` for language-model training, eval and sampling.

    The configuration names used below (`config`, `learning_rate`,
    `num_train_steps`, `num_warmup_steps`, `init_checkpoint`, `use_tpu`) are
    not parameters of this function; they are resolved from the surrounding
    scope.

    Args:
      features: mapping that must provide "input_ids".
      labels: unused.
      mode: a `tf.estimator.ModeKeys` value.
      params: only `params['model_dir']` is read, and only when training
        with `use_tpu` set.

    Returns:
      A `tf.contrib.tpu.TPUEstimatorSpec`. TRAIN and EVAL specs carry
      `model.lm_loss()` as the loss; any other mode yields predictions with
      the keys 'gt_logprobs', 'top_p_required', 'predictions',
      'pred_logprobs' and 'labels'.
    """
    tf.logging.info("*** Features ***")
    for feature_name in sorted(features):
        tf.logging.debug(" name = %s, shape = %s",
                         feature_name, features[feature_name].shape)

    input_ids = features["input_ids"]
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    model = GroverModel(
        config=config,
        is_training=is_training,
        input_ids=input_ids,
        pad_token_id=config.pad_token_id,
        chop_off_last_token=True,
    )
    total_loss = model.lm_loss()

    if not is_training:
        train_op, train_metrics = None, {}
        tvars = tf.trainable_variables()
    else:
        train_op, train_metrics = optimization_adafactor.create_optimizer(
            total_loss, learning_rate, num_train_steps, num_warmup_steps,
            use_tpu)
        # When training, the whole GLOBAL_VARIABLES collection (not just the
        # trainable variables) is offered to the checkpoint assignment map.
        # NOTE(review): presumably so optimizer state is restored too - confirm.
        tvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

    scaffold_fn = None
    initialized_variable_names = {}
    if init_checkpoint:
        assignment_map, initialized_variable_names = (
            get_assignment_map_from_checkpoint(tvars, init_checkpoint))
        if not use_tpu:
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        else:
            # With `use_tpu`, the restore is deferred into the scaffold
            # function that is handed to the estimator spec.
            def tpu_scaffold():
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold

    tf.logging.debug("**** Trainable Variables ****")
    for var in tvars:
        from_ckpt = var.name in initialized_variable_names
        init_string = ", *INIT_FROM_CKPT*" if from_ckpt else ""
        tf.logging.debug(" name = %s, shape = %s%s", var.name, var.shape,
                         init_string)

    if is_training:
        if not use_tpu:
            # Off-TPU: log a running mean of the loss every 100 iterations.
            loss_hook = tf.train.LoggingTensorHook(
                {'loss': tf.metrics.mean(total_loss)[1]}, every_n_iter=100)
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                training_hooks=[loss_hook],
                scaffold_fn=scaffold_fn)

        # On TPU the training scalars are reported through a host call.
        host_call = construct_scalar_host_call(
            metric_dict=train_metrics,
            model_dir=params['model_dir'],
            prefix='training/')
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            host_call=host_call,
            scaffold_fn=scaffold_fn)

    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(loss_value):
            """Return the streaming mean of the loss as "eval_loss"."""
            return {
                "eval_loss": tf.metrics.mean(values=loss_value),
            }

        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metrics=(metric_fn, [total_loss]),
            scaffold_fn=scaffold_fn)

    # Any remaining mode: score the ground-truth tokens and draw a sample.
    gt_logprobs = tf.squeeze(
        tf.batch_gather(model.log_probs, model.target_ids[:, :, None]),
        axis=2)

    # Probability mass of every token whose log-prob is strictly greater than
    # the ground-truth token's: the top-p threshold that top-p sampling would
    # need for the ground truth to be reachable.
    better_than_gt = model.log_probs > gt_logprobs[:, :, None]
    masked_probs = tf.cast(better_than_gt, tf.float32) * tf.exp(model.log_probs)
    top_p_required = tf.reduce_sum(masked_probs, axis=2)

    if use_tpu:
        # No top-p sampling here, since it seems to be too slow on TPUs;
        # sample from the full categorical distribution instead.
        sampled = tf.random.categorical(logits=model.logits_flat,
                                        num_samples=1)
    else:
        sampled = _top_p_sample(model.logits_flat, num_samples=1,
                                p=0.99)['sample']
    predictions = tf.reshape(sampled, get_shape_list(model.target_ids))

    pred_logprobs = tf.squeeze(
        tf.batch_gather(model.log_probs, predictions[:, :, None]), axis=2)

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions={
            'gt_logprobs': gt_logprobs,
            'top_p_required': top_p_required,
            'predictions': predictions,
            'pred_logprobs': pred_logprobs,
            'labels': input_ids,
        },
        scaffold_fn=scaffold_fn)