def model_fn(features, mode, params):
    """The `model_fn` for TPUEstimator.

    Args:
      features: Dict of input tensors. Outside PREDICT mode it must contain a
        "label" entry, which is fed to the eval metric function.
      mode: One of `tf_estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`.
      params: Estimator params; forwarded to `create_optimizer` in TRAIN mode.

    Returns:
      A `TPUEstimatorSpec` in TRAIN/EVAL mode, or an `EstimatorSpec` in
      PREDICT mode.

    Raises:
      ValueError: If `mode` is not TRAIN, EVAL or PREDICT.
    """
    label_ids = None
    if mode != tf_estimator.ModeKeys.PREDICT:
        label_ids = features["label"]

    model_config = runner_config["model_config"]
    loss, logits = create_model(model, model_config, features, mode,
                                runner_config["name"])

    if mode == tf_estimator.ModeKeys.TRAIN:
        train_op = create_optimizer(loss, runner_config, params)
        return tf_estimator.tpu.TPUEstimatorSpec(
            mode=mode, loss=loss, train_op=train_op)

    if mode == tf_estimator.ModeKeys.EVAL:
        # Single-label and multilabel tasks use different metric functions.
        if not model_config["multilabel"]:
            metric_fn = metric_functions.classification_metric
        else:
            metric_fn = metric_functions.labeling_metric
        eval_metrics = (metric_fn, [loss, label_ids, logits])
        return tf_estimator.tpu.TPUEstimatorSpec(
            mode=mode, loss=loss, eval_metrics=eval_metrics)

    if mode == tf_estimator.ModeKeys.PREDICT:
        predictions = {"logits": logits}
        # Softmax across classes for single-label models; an independent
        # sigmoid per class for multilabel models.
        if not model_config["multilabel"]:
            predictions["predictions"] = tf.nn.softmax(logits)
        else:
            predictions["predictions"] = tf.math.sigmoid(logits)
        return tf_estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # An explicit raise is kept under `python -O`, whereas `assert False`
    # would be stripped and the function would silently return None.
    raise ValueError("Expected to be called in TRAIN, EVAL, or PREDICT mode.")
def estimator_spec_predict(self, features, mesh, mesh_impl, use_tpu):
    """Builds the PREDICT-mode spec from Mesh TensorFlow samples.

    Args:
      features: Dict of input tensors. "inputs" (when the model has inputs) is
        used to trim the output batch. "infer_targets" and "inputs" are echoed
        into the predictions when present.
      mesh: The mtf mesh the samples are built on.
      mesh_impl: Mesh implementation used to lower `mesh`.
      use_tpu: Whether to return a `TPUEstimatorSpec` (after removing
        summaries) instead of an `EstimatorSpec`.

    Returns:
      A PREDICT-mode spec whose hooks restore the lowered mtf variables.
    """
    samples = mtf.anonymize(self.sample(features, mesh))
    lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
    outputs = lowering.export_to_tf_tensor(samples)

    if self.has_input:
        # Keep only the first N rows of the output, where N is the number of
        # rows in features["inputs"].
        rank = len(outputs.shape.as_list())
        batch_rows = tf.shape(features["inputs"])[0]
        begin = [0] * rank
        size = [batch_rows] + [-1] * (rank - 1)
        outputs = tf.slice(outputs, begin, size)

    predictions = {"outputs": outputs}
    for echoed_key in ("infer_targets", "inputs"):
        echoed_value = features.get(echoed_key)
        if echoed_value is not None:
            predictions[echoed_key] = echoed_value

    if not use_tpu:
        return tf_estimator.EstimatorSpec(
            tf_estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    t2t_model.remove_summaries()
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf_estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])
def model_function(features, labels, mode, params):
    """Builds the `tf.estimator.EstimatorSpec` to train/eval with.

    Args:
      features: Dict with "premise" and "hypothesis" tensors.
      labels: Class-index labels; cast to int32.
      mode: Estimator mode; a train_op is only built in TRAIN mode.
      params: Unused.

    Returns:
      An `EstimatorSpec` with a mean cross-entropy loss and an accuracy
      metric. If `FLAGS.checkpoint_file` is set, a scaffold restores the
      trainable variables from that file.
    """
    training = mode == tf_estimator.ModeKeys.TRAIN
    logits = predict(training, embeddings, features["premise"],
                     features["hypothesis"])
    int_labels = tf.to_int32(labels)

    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=int_labels, logits=logits)
    loss = tf.reduce_mean(per_example_loss)

    # Don't build the train_op unnecessarily, since the ADAM variables can
    # cause problems with loading checkpoints on CPUs.
    train_op = get_train_op(loss) if training else None

    predicted_class = tf.argmax(logits, 1, output_type=tf.int32)
    metrics = {"accuracy": tf.metrics.accuracy(predicted_class, int_labels)}

    scaffold = None
    checkpoint_file = FLAGS.checkpoint_file
    if checkpoint_file is not None:
        saver = tf.train.Saver(tf.trainable_variables())

        def _init_fn(_, sess):
            saver.restore(sess, checkpoint_file)

        scaffold = tf.train.Scaffold(init_fn=_init_fn)

    return tf_estimator.EstimatorSpec(
        mode=mode,
        scaffold=scaffold,
        loss=loss,
        predictions=None,
        train_op=train_op,
        eval_metric_ops=metrics)
def model_fn(features, labels, mode, params):
    """Model function.

    Retrieves blocks on CPU, reads the top `reader_beam_size` of them, and
    predicts the single most confident class across all read blocks.

    Args:
      features: Input feature dict, passed to the retriever and reader.
      labels: Class labels (cast to int32); unused in PREDICT mode.
      mode: Estimator mode. In PREDICT, the retriever beam equals the reader
        beam, and no loss, train_op or metrics are built.
      params: Dict with "reader_beam_size", "num_classes" and, outside
        PREDICT, "retriever_beam_size", "learning_rate" and "num_train_steps".

    Returns:
      An `EstimatorSpec` with predictions {"answer": ...}.
    """
    reader_beam_size = params["reader_beam_size"]
    num_classes = params["num_classes"]
    is_predict = mode == tf_estimator.ModeKeys.PREDICT
    retriever_beam_size = (reader_beam_size if is_predict
                           else params["retriever_beam_size"])
    assert reader_beam_size <= retriever_beam_size

    with tf.device("/cpu:0"):
        retriever_outputs = orqa_model.retrieve(
            features=features,
            retriever_beam_size=retriever_beam_size,
            mode=mode,
            params=params)

    with tf.variable_scope("reader"):
        # [reader_beam_size, num_classes]
        final_logits = read(
            features=features,
            retriever_logits=retriever_outputs.logits[:reader_beam_size],
            blocks=retriever_outputs.blocks[:reader_beam_size],
            mode=mode,
            params=params)

    # [] (scalar): argmax over every (block, class) pair, then recover the
    # class index from the flattened position.
    flat_logits = tf.reshape(final_logits, [reader_beam_size * num_classes])
    predictions = tf.math.floormod(tf.argmax(flat_logits), num_classes)

    loss = None
    train_op = None
    eval_metric_ops = None
    if not is_predict:
        labels = tf.cast(labels, tf.int32)
        eval_metric_ops = compute_eval_metrics(
            labels=labels, predictions=predictions)
        loss = compute_loss(labels, final_logits)
        num_train_steps = params["num_train_steps"]
        warmup_steps = min(10000, max(100, int(num_train_steps / 10)))
        train_op = optimization.create_optimizer(
            loss=loss,
            init_lr=params["learning_rate"],
            num_train_steps=num_train_steps,
            num_warmup_steps=warmup_steps,
            use_tpu=False)

    return tf_estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions={"answer": predictions},
        eval_metric_ops=eval_metric_ops)
def model_fn(features, labels, mode):
    """The model function for creating an Estimator.

    Args:
      features: Input dict; "input_refs" is indexed as [:, :, 0] / [:, :, 1]
        to count inputs, and the whole dict goes to the core graph.
      labels: Unused.
      mode: PREDICT decodes sequences, EVAL computes metrics, and anything
        else must be TRAIN.

    Returns:
      An `EstimatorSpec` for `mode`.
    """
    del labels
    refs = features["input_refs"]
    # An input counts when its end reference ([..., 1]) exceeds its start
    # reference ([..., 0]).
    input_count = tf.reduce_sum(
        tf.to_int32(tf.greater(refs[:, :, 1], refs[:, :, 0])))
    tf.summary.scalar("input_count", input_count)

    loss_dict, pred_dict, areas = seq2act_model.core_graph(
        features, hparams, mode, compute_additional_loss_fn)

    if mode == tf_estimator.ModeKeys.PREDICT:
        pred_dict["sequences"] = decode_sequence(
            features, areas, hparams, decode_length,
            post_processing=FLAGS.post_processing)
        return tf_estimator.EstimatorSpec(mode, predictions=pred_dict)

    if mode == tf_estimator.ModeKeys.EVAL:
        metrics = {}
        _eval(metrics, pred_dict, loss_dict, features, areas,
              compute_seq_accuracy, hparams,
              metric_types=FLAGS.metric_types.split(","),
              decode_length=decode_length)
        if compute_additional_metric_fn:
            compute_additional_metric_fn(metrics, pred_dict, features)
        return tf_estimator.EstimatorSpec(
            mode, loss=loss_dict["total_loss"], eval_metric_ops=metrics)

    assert mode == tf_estimator.ModeKeys.TRAIN
    loss = loss_dict["total_loss"]
    # Summarize each scalar component loss. Skip the total, and skip keys
    # ending in "losses" (presumably non-scalar collections).
    for component_name, component_value in loss_dict.items():
        if component_name == "total_loss" or component_name.endswith("losses"):
            continue
        tf.summary.scalar(component_name, component_value)

    # The learning rate is the product of the "*"-separated factors named in
    # hparams.learning_rate_schedule.
    step_num = tf.to_float(tf.train.get_global_step())
    factor_names = [
        part.strip() for part in hparams.learning_rate_schedule.split("*")
    ]
    lr = tf.constant(1.0)
    for factor_name in filter(None, factor_names):
        lr *= learning_rate.learning_rate_factor(factor_name, step_num, hparams)

    train_op = optimize.optimize(loss, lr, hparams)
    return tf_estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
def _build_estimator_spec(losses, trainer_params, mode, use_tpu=False):
    """Builds an EstimatorSpec/TPUEstimatorSpec based on trainer_params.

    Args:
      losses: A dictionary of {string: tf.Tensor} containing the various losses.
        The keys will be used as display names for the summaries, the values
        will be summed up to obtain the total loss, which is to be minimized.
      trainer_params: A ParameterContainer object with parameters relevant to
        the training (`learning_rate`, `clip_gradients`).
      mode: One of tf.estimator.ModeKeys: TRAIN, PREDICT or EVAL. The summed
        loss is built for TRAIN and EVAL, and a train_op only for TRAIN. No
        predictions are attached, so a PREDICT spec will presumably be
        rejected by the spec constructor.
      use_tpu: A boolean, if True, a TPU-compatible version of EstimatorSpec
        will be built.

    Returns:
      A EstimatorSpec or a TPUEstimatorSpec object for `mode`.
    """
    total_loss = None
    train_op = None

    if mode in (tf_estimator.ModeKeys.TRAIN, tf_estimator.ModeKeys.EVAL):
        total_loss = 0.0
        for loss_name, loss in six.iteritems(losses):
            # Per-loss summaries are only emitted off-TPU.
            if not use_tpu:
                tf.summary.scalar('Loss/%s' % loss_name, loss)
            total_loss += loss

    if mode == tf_estimator.ModeKeys.TRAIN:
        learning_rate = trainer_params.learning_rate
        maybe_summary.scalar('Learning Rate', learning_rate)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                           beta1=0.9)
        optimizer = contrib_estimator.clip_gradients_by_norm(
            optimizer, trainer_params.clip_gradients)
        if use_tpu:
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)
        train_op = optimizer.minimize(total_loss,
                                      global_step=tf.train.get_global_step())

    # Use the caller's mode so EVAL gets an EVAL spec carrying the loss,
    # rather than a TRAIN spec with neither loss nor train_op.
    if use_tpu:
        return tf_estimator.tpu.TPUEstimatorSpec(
            mode=mode, loss=total_loss, train_op=train_op)
    return tf_estimator.EstimatorSpec(
        mode=mode, loss=total_loss, train_op=train_op)
def estimator_spec_eval(self, features, logits, labels, loss, restore_hook,
                        use_tpu):
    """Construct EstimatorSpec for EVAL mode.

    Rank-3 logits are expanded to rank 5 (new axes at positions 2 and 3)
    before metrics are computed. On TPU, metrics whose last "/"-separated name
    component is in `t2t_model.TPU_METRIC_BLACKLIST` are skipped, and the rest
    are built on the host CPU. Off TPU, all metrics are computed against
    features["targets"].
    """
    hparams = self.hparams
    problem = hparams.problem
    if logits.get_shape().ndims == 3:
        logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)

    # Support for multiproblem: use the problem's task list when it has one.
    task_list = getattr(problem, "task_list", [problem])
    eval_metrics_fns = metrics.create_evaluation_metrics(task_list, hparams)

    if not use_tpu:
        eval_metrics = {
            name: fn(logits, features, features["targets"])
            for name, fn in six.iteritems(eval_metrics_fns)
        }
        return tf_estimator.EstimatorSpec(
            tf_estimator.ModeKeys.EVAL,
            predictions={"predictions": logits},
            eval_metric_ops=eval_metrics,
            evaluation_hooks=[restore_hook],
            loss=loss)

    def metric_fn(tf_logits, labels):
        with tf.device("cpu:0"), mtf.utils.outside_all_rewrites():
            return {
                name: fn(tf_logits, None, tf.identity(labels))
                for name, fn in six.iteritems(eval_metrics_fns)
                if name.split("/")[-1] not in t2t_model.TPU_METRIC_BLACKLIST
            }

    return tpu_estimator.TPUEstimatorSpec(
        tf_estimator.ModeKeys.EVAL,
        evaluation_hooks=[restore_hook],
        loss=loss,
        eval_metrics=(metric_fn, [logits, labels]))
def _model_fn(features, labels, mode):
    """A model_fn that uses a mock TF-Hub module.

    The module embeds features[_TEXT_FEATURE_NAME]. Training is a no-op: the
    loss is a constant zero and the train_op only increments the global step.
    """
    del labels
    module = hub.Module(hub.create_module_spec(text_module_fn))
    if register_module:
        hub.register_module_for_export(module, _EXPORT_MODULE_NAME)

    embedded = module(features[_TEXT_FEATURE_NAME])
    zero_loss = tf.constant(0.0)
    step = tf.compat.v1.train.get_global_step()
    increment_step = tf.compat.v1.assign_add(step, 1)

    return tf_estimator.EstimatorSpec(
        mode=mode,
        predictions=embedded,
        loss=zero_loss,
        train_op=increment_step)
def _simple_model_fn(self, features, labels, mode, params):
    """Single dense-unit logistic regression trained with plain SGD (lr 0.1).

    When params["use_tpu"] is truthy, the optimizer is wrapped in
    `CrossShardOptimizer` and a `TPUEstimatorSpec` is returned; otherwise an
    `EstimatorSpec` is returned.
    """
    use_tpu = params["use_tpu"]
    logits = tf.squeeze(tf.layers.dense(features, 1))
    per_example = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.to_float(labels), logits=logits)
    loss = tf.reduce_mean(per_example)

    optimizer = tf.train.GradientDescentOptimizer(0.1)
    if use_tpu:
        optimizer = tf.tpu.CrossShardOptimizer(optimizer)
    train_op = optimizer.minimize(
        loss, global_step=tf.train.get_or_create_global_step())

    spec_cls = (tf_estimator.tpu.TPUEstimatorSpec if use_tpu
                else tf_estimator.EstimatorSpec)
    return spec_cls(mode=mode, loss=loss, train_op=train_op)
def _model_fn(features, labels, params, mode=None):
    """Returns tf.estimator.EstimatorSpec.

    The loss is built in every mode except PREDICT; a train_op only in TRAIN.
    No eval metrics or evaluation hooks are attached.
    """
    num_output_classes = len(label_vocab)
    predictions, predictions_for_loss = _make_prediction_ops(
        features=features,
        hparams=params,
        mode=mode,
        num_output_classes=num_output_classes)

    loss = None
    train_op = None
    if mode != tf_estimator.ModeKeys.PREDICT:
        loss = _make_loss(predictions_for_loss=predictions_for_loss,
                          labels=labels,
                          num_output_classes=num_output_classes)
        if mode == tf_estimator.ModeKeys.TRAIN:
            train_op = _make_train_op(loss=loss, hparams=params)

    return tf_estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=None,
        evaluation_hooks=[],
    )
def _gpu_estimator_spec_eval(self, features, logits, labels, loss,
                             losses_dict):
    """Construct EstimatorSpec for GPU EVAL mode.

    Builds every metric for hparams.problem except those whose name contains
    "rouge" or "bleu". `labels` and `losses_dict` are not used.

    Raises:
      NotImplementedError: If hparams has no `problem` attribute.
    """
    hparams = self.hparams
    if not hasattr(hparams, "problem"):
        raise NotImplementedError(
            "hparams is missing attribute `problem`. NasSeq2Seq must "
            "be used with a problem.")

    # TPU is not supported.
    metric_fns = metrics.create_evaluation_metrics([hparams.problem], hparams)
    excluded_tags = ("rouge", "bleu")
    eval_metrics = {
        name: fn(logits, features, features["targets"])
        for name, fn in six.iteritems(metric_fns)
        if not any(tag in name for tag in excluded_tags)
    }
    return tf_estimator.EstimatorSpec(
        tf_estimator.ModeKeys.EVAL,
        predictions={"predictions": logits},
        eval_metric_ops=eval_metrics,
        loss=loss)
def model_fn(features, labels, mode, params):
    """The `model_fn` for tf.Estimator (PREDICT only).

    Logs each feature's name and shape, runs the few-shot model, and returns
    the argmax class per example together with its "guid".

    Raises:
      ValueError: If `mode` is not PREDICT.
    """
    del labels, params
    if mode != tf_estimator.ModeKeys.PREDICT:
        raise ValueError("Only PREDICT mode is supported: %s" % (mode))

    tf.logging.info("*** Features *** %s %s" % (type(features), features))
    for feature_name in sorted(features):
        tf.logging.info(" name = %s, shape = %s" %
                        (feature_name, features[feature_name].shape))

    logits = create_model(
        bert_config=bert_config,
        is_training=False,
        fewshot_num_examples_per_class=fewshot_num_examples_per_class,
        input_ids=features["input_ids"],
        input_mask=features["input_mask"],
        segment_ids=features["segment_ids"],
        use_one_hot_embeddings=use_one_hot_embeddings,
        tokenizer=tokenizer,
        class_examples_combiner=class_examples_combiner)

    predicted_class = tf.argmax(logits, axis=-1, output_type=tf.int32)
    return tf_estimator.EstimatorSpec(
        mode=mode,
        predictions={
            "predictions": predicted_class,
            "guid": features["guid"],
        })
def model_fn(features, labels, mode, params):
    """Model function.

    Embeds queries and blocks with the ICT module and scores every query
    against every block embedding (cross-shard on TPU). Query i's target is
    the position of its own block in that set.

    Args:
      features: Dict with "block_ids", "block_mask", "block_segment_ids",
        "query_ids", "query_mask" and "mask_query" tensors.
      labels: Unused; targets are derived from batch positions.
      mode: Estimator mode; forwarded to `create_ict_module`.
      params: Dict with "learning_rate", "num_train_steps" and, optionally,
        "use_tpu" (treated as False when absent).

    Returns:
      A `TPUEstimatorSpec` when use_tpu is set, else an `EstimatorSpec`. Both
      carry a loss, a train_op and accuracy/padding eval metrics.
    """
    del labels
    # "use_tpu" is optional; read it once so every branch agrees on the default.
    use_tpu = params["use_tpu"] if "use_tpu" in params else False

    # [local_batch_size, block_seq_len]
    block_ids = features["block_ids"]
    block_mask = features["block_mask"]
    block_segment_ids = features["block_segment_ids"]

    # [local_batch_size, query_seq_len]
    query_ids = features["query_ids"]
    query_mask = features["query_mask"]

    local_batch_size = tensor_utils.shape(block_ids, 0)
    # %s works whether tensor_utils.shape returns a static int or (presumably)
    # a dynamic tensor; %d would fail to format the latter.
    tf.logging.info("Model batch size: %s", local_batch_size)

    ict_module = create_ict_module(params, mode)

    query_emb = ict_module(
        inputs=dict(input_ids=query_ids,
                    input_mask=query_mask,
                    segment_ids=tf.zeros_like(query_ids)),
        signature="projected")
    block_emb = ict_module(
        inputs=dict(input_ids=block_ids,
                    input_mask=block_mask,
                    segment_ids=block_segment_ids),
        signature="projected")

    if use_tpu:
        # [global_batch_size, hidden_size]
        block_emb = tpu_utils.cross_shard_concat(block_emb)

        # [global_batch_size, local_batch_size]
        labels = tpu_utils.cross_shard_pad(tf.eye(local_batch_size))

        # [local_batch_size]
        labels = tf.argmax(labels, 0)
    else:
        # [local_batch_size]
        labels = tf.range(local_batch_size)

    tf.logging.info("Global batch size: %s", tensor_utils.shape(block_emb, 0))

    # [batch_size, global_batch_size]
    logits = tf.matmul(query_emb, block_emb, transpose_b=True)

    # []
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    train_op = optimization.create_optimizer(
        loss=loss,
        init_lr=params["learning_rate"],
        num_train_steps=params["num_train_steps"],
        num_warmup_steps=min(10000, max(100,
                                        int(params["num_train_steps"] / 10))),
        use_tpu=use_tpu)

    predictions = tf.argmax(logits, -1)

    metric_args = [
        query_mask, block_mask, labels, predictions, features["mask_query"]
    ]

    def metric_fn(query_mask, block_mask, labels, predictions, mask_query):
        """Accuracy split by mask_query, plus mean padding/mask ratios."""
        masked_accuracy = tf.metrics.accuracy(labels=labels,
                                              predictions=predictions,
                                              weights=mask_query)
        unmasked_accuracy = tf.metrics.accuracy(
            labels=labels,
            predictions=predictions,
            weights=tf.logical_not(mask_query))
        return dict(query_non_padding=tf.metrics.mean(query_mask),
                    block_non_padding=tf.metrics.mean(block_mask),
                    actual_mask_ratio=tf.metrics.mean(mask_query),
                    masked_accuracy=masked_accuracy,
                    unmasked_accuracy=unmasked_accuracy)

    if use_tpu:
        return tf_estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            eval_metrics=(metric_fn, metric_args))
    return tf_estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=metric_fn(*metric_args),
                                      predictions=predictions)
def estimator_spec_train(self, loss, num_async_replicas=1, use_tpu=False):
    """Constructs `tf.estimator.EstimatorSpec` for TRAIN (training) mode.

    For pruning-based sparsity techniques, a magnitude-pruning mask update is
    chained after the optimizer step. This is skipped when masks are loaded
    from a checkpoint, so the loaded masks stay static. Variables are
    optionally initialized from the checkpoints named by hparams
    `warm_start_from`, `load_masks_from` and `load_weights_from`, with the
    same precedence on TPU and non-TPU.

    Args:
      loss: Training loss tensor.
      num_async_replicas: Forwarded to `self.optimize`.
      use_tpu: If True, build a `TPUEstimatorSpec` whose checkpoint
        initialization runs inside `scaffold_fn`.

    Returns:
      A `TPUEstimatorSpec` if `use_tpu`, else an `EstimatorSpec`, in TRAIN
      mode.
    """
    train_op = self.optimize(loss, num_async_replicas=num_async_replicas,
                             use_tpu=use_tpu)

    sparsity_technique = self._hparams.get("sparsity_technique")
    # `get` yields None when the hparam is unset, and `"pruning" in None`
    # would raise TypeError.
    if sparsity_technique and "pruning" in sparsity_technique:
        if not self._hparams.load_masks_from:
            # If we are loading trained masks, don't add the mask update
            # step to the training process and keep the masks static
            with tf.control_dependencies([train_op]):
                mp_hparams = pruning_hparams(
                    self._hparams, use_tpu,
                    sparsity_technique == "random_pruning")
                p = magnitude_pruning.Pruning(
                    mp_hparams, global_step=tf.train.get_global_step())
                mask_update_op = p.conditional_mask_update_op()
                train_op = mask_update_op
        check_global_sparsity()

    hp = self._hparams

    def _init_from_checkpoints():
        """Initializes variables from the checkpoints named in hparams."""
        if hp.warm_start_from:
            self.initialize_from_ckpt(hp.warm_start_from)
        elif hp.load_masks_from and hp.load_weights_from:
            self.initialize_masks_from_ckpt(hp.load_masks_from)
            self.initialize_non_masks_from_ckpt(hp.load_weights_from)
        elif hp.load_masks_from:
            self.initialize_masks_from_ckpt(hp.load_masks_from)

    if use_tpu:
        # On TPU the initialization is deferred into scaffold_fn. With neither
        # warm_start_from nor load_masks_from set, nothing is loaded.
        if hp.warm_start_from or hp.load_masks_from:
            def scaffold_fn():
                _init_from_checkpoints()
                return tf.train.Scaffold()
        else:
            scaffold_fn = None

        # Note: important to call this before remove_summaries()
        if self.hparams.tpu_enable_host_call:
            host_call = t2t_model.create_host_call(self.hparams.model_dir)
        else:
            host_call = None

        t2t_model.remove_summaries()

        return contrib_tpu.TPUEstimatorSpec(tf_estimator.ModeKeys.TRAIN,
                                            loss=loss,
                                            train_op=train_op,
                                            host_call=host_call,
                                            scaffold_fn=scaffold_fn)

    # Same precedence as the TPU path, so load_weights_from is honored
    # alongside load_masks_from here too.
    _init_from_checkpoints()
    return tf_estimator.EstimatorSpec(tf_estimator.ModeKeys.TRAIN,
                                      loss=loss,
                                      train_op=train_op)
def _model_fn(features, labels, mode, params, config):
    """Defines an `Estimator` `model_fn` from the enclosing Keras ranking model.

    If `weights_feature_name` is set, then outside PREDICT that feature is
    popped from `features` (mutating the dict), reshaped to 2-D, and used as
    the sample weight. PREDICT returns logits with export outputs only. Other
    modes also compute the model loss and Keras metrics, and TRAIN
    additionally builds the optimizer update plus the model's update ops.

    Raises:
      ValueError: If the weights feature is missing from `features`, or if
        `serving_default` is not "regress" or "predict".
    """
    del config, params
    # In Estimator, all sub-graphs need to be constructed inside the model_fn.
    # Hence, ranker, losses, metrics and optimizer are cloned inside this
    # function.
    ranker = tf.keras.models.clone_model(model, clone_function=_clone_fn)
    training = mode == tf_estimator.ModeKeys.TRAIN

    weights = None
    if weights_feature_name and mode != tf_estimator.ModeKeys.PREDICT:
        if weights_feature_name not in features:
            raise ValueError(
                "weights_feature '{0}' can not be found in 'features'.".
                format(weights_feature_name))
        weights = utils.reshape_to_2d(features.pop(weights_feature_name))

    logits = ranker(features, training=training)

    if serving_default not in ["regress", "predict"]:
        raise ValueError("serving_default should be 'regress' or 'predict', "
                         "but got {}".format(serving_default))
    default_output_cls = (tf_estimator.export.RegressionOutput
                          if serving_default == "regress" else
                          tf_estimator.export.PredictOutput)
    export_outputs = {
        _DEFAULT_SERVING_KEY: default_output_cls(logits),
        _REGRESS_SERVING_KEY: tf_estimator.export.RegressionOutput(logits),
        _PREDICT_SERVING_KEY: tf_estimator.export.PredictOutput(logits),
    }

    if mode == tf_estimator.ModeKeys.PREDICT:
        return tf_estimator.EstimatorSpec(mode=mode,
                                          predictions=logits,
                                          export_outputs=export_outputs)

    loss_fn = _clone_fn(model.loss)
    total_loss = loss_fn(labels, logits, sample_weight=weights)

    # Adding default metrics here as model.metrics does not contain custom
    # metrics.
    keras_metrics = ([_clone_fn(m) for m in model.metrics] +
                     metrics.default_keras_metrics())
    eval_metric_ops = {}
    for keras_metric in keras_metrics:
        keras_metric.update_state(labels, logits, sample_weight=weights)
        eval_metric_ops[keras_metric.name] = keras_metric

    train_op = None
    if training:
        optimizer = _clone_fn(model.optimizer)
        optimizer.iterations = tf.compat.v1.train.get_or_create_global_step()
        # Get both the unconditional updates (the None part)
        # and the input-conditional updates (the features part).
        # These updates are for layers like BatchNormalization, which have
        # separate update and minimize ops.
        update_ops = (ranker.get_updates_for(None) +
                      ranker.get_updates_for(features))
        minimize_op = optimizer.get_updates(
            loss=total_loss, params=ranker.trainable_variables)[0]
        train_op = tf.group(minimize_op, *update_ops)

    return tf_estimator.EstimatorSpec(mode=mode,
                                      predictions=logits,
                                      loss=total_loss,
                                      train_op=train_op,
                                      eval_metric_ops=eval_metric_ops,
                                      export_outputs=export_outputs)
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Builds the MNIST model as a Mesh TensorFlow graph, lowers it onto a
    placement mesh whose shape and layout come from FLAGS.mesh_shape and
    FLAGS.layout, and returns the spec for `mode`.

    TRAIN: Adafactor updates plus a global-step increment, with a sharded
    Saver and a CheckpointSaverHook (every 1000 steps). PREDICT: class argmax
    and softmax probabilities, also exported as "classify". EVAL: loss and
    accuracy. Every mode attaches an MtfRestoreHook. Any other mode falls
    through and returns None.
    """
    tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                    (features, labels, mode, params))
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = mnist_model(features, labels, mesh)
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    mesh_size = mesh_shape.size
    # One (empty) device string per mesh position.
    mesh_devices = [""] * mesh_size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    # Gradient/update ops must be added to the mtf graph before lowering.
    if mode == tf_estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()
        update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.export_to_tf_tensor(logits)
    if mode != tf_estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)
        tf.summary.scalar("loss", tf_loss)

    if mode == tf_estimator.ModeKeys.TRAIN:
        # Lowered mtf updates plus the global-step increment form train_op.
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False, save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            FLAGS.model_dir,
            save_steps=1000,
            saver=saver,
            listeners=[saver_listener])

        accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(tf_logits, axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar("train_accuracy", accuracy[1])

        # restore_hook must come before saver_hook
        return tf_estimator.EstimatorSpec(
            tf_estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])

    if mode == tf_estimator.ModeKeys.PREDICT:
        predictions = {
            "classes": tf.argmax(tf_logits, axis=1),
            "probabilities": tf.nn.softmax(tf_logits),
        }
        return tf_estimator.EstimatorSpec(
            mode=tf_estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "classify": tf_estimator.export.PredictOutput(predictions)
            })
    if mode == tf_estimator.ModeKeys.EVAL:
        return tf_estimator.EstimatorSpec(
            mode=tf_estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={
                "accuracy":
                    tf.metrics.accuracy(
                        labels=labels, predictions=tf.argmax(tf_logits, axis=1)),
            })
def model_fn(features, labels, mode, params, config):
    """Builds the acoustic model.

    Separate stacks predict onsets, offsets, velocities and frame activations.
    A final frame stack combines the onset, activation and offset
    probabilities (each optionally gradient-stopped) into frame
    probabilities.

    Args:
      features: Object exposing `length`, `spec` and `sequence_id`.
      labels: Object exposing `onsets`, `offsets`, `velocities`, `labels`,
        `label_weights` and `note_sequence`. `labels`, `onsets` and
        `note_sequence` are read in every mode for metrics and predictions.
      mode: Estimator mode; losses, images and train_op only in TRAIN.
      params: The hparams object.
      config: Unused.

    Returns:
      An `EstimatorSpec` with predictions, per-metric mean eval ops, and (in
      TRAIN) the total loss and train_op.

    Raises:
      ValueError: If stop_activation_gradient is set without activation_loss.
    """
    del config
    hparams = params

    length = features.length
    spec = features.spec

    is_training = mode == tf_estimator.ModeKeys.TRAIN

    if is_training:
        onset_labels = labels.onsets
        offset_labels = labels.offsets
        velocity_labels = labels.velocities
        frame_labels = labels.labels
        frame_label_weights = labels.label_weights

    if hparams.stop_activation_gradient and not hparams.activation_loss:
        raise ValueError(
            'If stop_activation_gradient is true, activation_loss must be true.'
        )

    losses = {}
    with slim.arg_scope([slim.batch_norm, slim.dropout],
                        is_training=is_training):
        with tf.variable_scope('onsets'):
            onset_outputs = acoustic_model(spec,
                                           hparams,
                                           lstm_units=hparams.onset_lstm_units,
                                           lengths=length)
            onset_probs = slim.fully_connected(onset_outputs,
                                               constants.MIDI_PITCHES,
                                               activation_fn=tf.sigmoid,
                                               scope='onset_probs')

            # onset_probs_flat is used during inference.
            onset_probs_flat = flatten_maybe_padded_sequences(
                onset_probs, length)
            if is_training:
                onset_labels_flat = flatten_maybe_padded_sequences(
                    onset_labels, length)
                onset_losses = tf_utils.log_loss(onset_labels_flat,
                                                 onset_probs_flat)
                tf.losses.add_loss(tf.reduce_mean(onset_losses))
                losses['onset'] = onset_losses
        with tf.variable_scope('offsets'):
            offset_outputs = acoustic_model(
                spec,
                hparams,
                lstm_units=hparams.offset_lstm_units,
                lengths=length)
            offset_probs = slim.fully_connected(offset_outputs,
                                                constants.MIDI_PITCHES,
                                                activation_fn=tf.sigmoid,
                                                scope='offset_probs')

            # offset_probs_flat is used during inference.
            offset_probs_flat = flatten_maybe_padded_sequences(
                offset_probs, length)
            if is_training:
                offset_labels_flat = flatten_maybe_padded_sequences(
                    offset_labels, length)
                offset_losses = tf_utils.log_loss(offset_labels_flat,
                                                  offset_probs_flat)
                tf.losses.add_loss(tf.reduce_mean(offset_losses))
                losses['offset'] = offset_losses
        with tf.variable_scope('velocity'):
            velocity_outputs = acoustic_model(
                spec,
                hparams,
                lstm_units=hparams.velocity_lstm_units,
                lengths=length)
            velocity_values = slim.fully_connected(velocity_outputs,
                                                   constants.MIDI_PITCHES,
                                                   activation_fn=None,
                                                   scope='onset_velocities')

            velocity_values_flat = flatten_maybe_padded_sequences(
                velocity_values, length)
            if is_training:
                velocity_labels_flat = flatten_maybe_padded_sequences(
                    velocity_labels, length)
                # Squared velocity error, counted only where an onset label
                # is set.
                velocity_loss = tf.reduce_sum(
                    onset_labels_flat *
                    tf.square(velocity_labels_flat - velocity_values_flat),
                    axis=1)
                tf.losses.add_loss(tf.reduce_mean(velocity_loss))
                losses['velocity'] = velocity_loss

        with tf.variable_scope('frame'):
            if not hparams.share_conv_features:
                # TODO(eriche): this is broken when hparams.frame_lstm_units > 0
                activation_outputs = acoustic_model(
                    spec,
                    hparams,
                    lstm_units=hparams.frame_lstm_units,
                    lengths=length)
                activation_probs = slim.fully_connected(
                    activation_outputs,
                    constants.MIDI_PITCHES,
                    activation_fn=tf.sigmoid,
                    scope='activation_probs')
            else:
                activation_probs = slim.fully_connected(
                    onset_outputs,
                    constants.MIDI_PITCHES,
                    activation_fn=tf.sigmoid,
                    scope='activation_probs')

            probs = []
            if hparams.stop_onset_gradient:
                probs.append(tf.stop_gradient(onset_probs))
            else:
                probs.append(onset_probs)

            if hparams.stop_activation_gradient:
                probs.append(tf.stop_gradient(activation_probs))
            else:
                probs.append(activation_probs)

            if hparams.stop_offset_gradient:
                probs.append(tf.stop_gradient(offset_probs))
            else:
                probs.append(offset_probs)

            combined_probs = tf.concat(probs, 2)

            if hparams.combined_lstm_units > 0:
                outputs = lstm_layer(
                    combined_probs,
                    hparams.combined_lstm_units,
                    lengths=length if hparams.use_lengths else None,
                    stack_size=hparams.combined_rnn_stack_size,
                    use_cudnn=hparams.use_cudnn,
                    bidirectional=hparams.bidirectional)
            else:
                outputs = combined_probs

            frame_probs = slim.fully_connected(outputs,
                                               constants.MIDI_PITCHES,
                                               activation_fn=tf.sigmoid,
                                               scope='frame_probs')

        frame_probs_flat = flatten_maybe_padded_sequences(frame_probs, length)

        if is_training:
            frame_labels_flat = flatten_maybe_padded_sequences(
                frame_labels, length)
            frame_label_weights_flat = flatten_maybe_padded_sequences(
                frame_label_weights, length)
            if hparams.weight_frame_and_activation_loss:
                frame_loss_weights = frame_label_weights_flat
            else:
                frame_loss_weights = None
            frame_losses = tf_utils.log_loss(frame_labels_flat,
                                             frame_probs_flat,
                                             weights=frame_loss_weights)
            tf.losses.add_loss(tf.reduce_mean(frame_losses))
            losses['frame'] = frame_losses

            if hparams.activation_loss:
                # The loss below is computed on flattened sequences, so the
                # weights must be flattened the same way (as for frame loss).
                if hparams.weight_frame_and_activation_loss:
                    activation_loss_weights = frame_label_weights_flat
                else:
                    activation_loss_weights = None
                activation_losses = tf_utils.log_loss(
                    frame_labels_flat,
                    flatten_maybe_padded_sequences(activation_probs, length),
                    weights=activation_loss_weights)
                tf.losses.add_loss(tf.reduce_mean(activation_losses))
                losses['activation'] = activation_losses

    frame_predictions = frame_probs_flat > hparams.predict_frame_threshold
    onset_predictions = onset_probs_flat > hparams.predict_onset_threshold
    offset_predictions = offset_probs_flat > hparams.predict_offset_threshold

    # Re-add a leading axis of size 1 to the flattened predictions.
    frame_predictions = tf.expand_dims(frame_predictions, axis=0)
    onset_predictions = tf.expand_dims(onset_predictions, axis=0)
    offset_predictions = tf.expand_dims(offset_predictions, axis=0)
    velocity_values = tf.expand_dims(velocity_values_flat, axis=0)

    metrics_values = metrics.define_metrics(
        frame_probs=frame_probs,
        onset_probs=onset_probs,
        frame_predictions=frame_predictions,
        onset_predictions=onset_predictions,
        offset_predictions=offset_predictions,
        velocity_values=velocity_values,
        length=features.length,
        sequence_label=labels.note_sequence,
        frame_labels=labels.labels,
        sequence_id=features.sequence_id,
        hparams=hparams)

    for label, loss_collection in losses.items():
        loss_label = 'losses/' + label
        metrics_values[loss_label] = loss_collection

    def predict_sequence():
        """Convert frame predictions into a sequence (TF)."""

        def _predict(frame_probs, onset_probs, frame_predictions,
                     onset_predictions, offset_predictions, velocity_values):
            """Convert frame predictions into a sequence (Python)."""
            sequence = infer_util.predict_sequence(
                frame_probs=frame_probs,
                onset_probs=onset_probs,
                frame_predictions=frame_predictions,
                onset_predictions=onset_predictions,
                offset_predictions=offset_predictions,
                velocity_values=velocity_values,
                hparams=hparams,
                min_pitch=constants.MIN_MIDI_PITCH)
            return sequence.SerializeToString()

        # Only the first batch element ([0]) is converted.
        sequence = tf.py_func(_predict,
                              inp=[
                                  frame_probs[0],
                                  onset_probs[0],
                                  frame_predictions[0],
                                  onset_predictions[0],
                                  offset_predictions[0],
                                  velocity_values[0],
                              ],
                              Tout=tf.string,
                              stateful=False)
        sequence.set_shape([])
        return tf.expand_dims(sequence, axis=0)

    predictions = {
        'frame_probs': frame_probs,
        'onset_probs': onset_probs,
        'frame_predictions': frame_predictions,
        'onset_predictions': onset_predictions,
        'offset_predictions': offset_predictions,
        'velocity_values': velocity_values,
        'sequence_predictions': predict_sequence(),
        # Include some features and labels in output because Estimator 'predict'
        # API does not give access to them.
        'sequence_ids': features.sequence_id,
        'sequence_labels': labels.note_sequence,
        'frame_labels': labels.labels,
        'onset_labels': labels.onsets,
    }
    for k, v in metrics_values.items():
        predictions[k] = tf.stack(v)

    metric_ops = {k: tf.metrics.mean(v) for k, v in metrics_values.items()}

    train_op = None
    loss = None
    if is_training:
        # Pianoroll images: labels in the red channel, probabilities in green,
        # zeros in blue.
        images = {}
        onset_pianorolls = tf.concat([
            onset_labels[:, :, :, tf.newaxis], onset_probs[:, :, :, tf.newaxis],
            tf.zeros(tf.shape(onset_labels))[:, :, :, tf.newaxis]
        ],
                                     axis=3)
        images['OnsetPianorolls'] = onset_pianorolls
        offset_pianorolls = tf.concat([
            offset_labels[:, :, :, tf.newaxis], offset_probs[:, :, :,
                                                             tf.newaxis],
            tf.zeros(tf.shape(offset_labels))[:, :, :, tf.newaxis]
        ],
                                      axis=3)
        images['OffsetPianorolls'] = offset_pianorolls
        activation_pianorolls = tf.concat([
            frame_labels[:, :, :, tf.newaxis], frame_probs[:, :, :, tf.newaxis],
            tf.zeros(tf.shape(frame_labels))[:, :, :, tf.newaxis]
        ],
                                          axis=3)
        images['ActivationPianorolls'] = activation_pianorolls
        for name, image in images.items():
            tf.summary.image(name, image)

        loss = tf.losses.get_total_loss()
        tf.summary.scalar('loss', loss)
        for label, loss_collection in losses.items():
            loss_label = 'losses/' + label
            tf.summary.scalar(loss_label, tf.reduce_mean(loss_collection))

        train_op = slim.optimize_loss(
            name='training',
            loss=loss,
            global_step=tf.train.get_or_create_global_step(),
            learning_rate=hparams.learning_rate,
            learning_rate_decay_fn=functools.partial(
                tf.train.exponential_decay,
                decay_steps=hparams.decay_steps,
                decay_rate=hparams.decay_rate,
                staircase=True),
            clip_gradients=hparams.clip_norm,
            optimizer='Adam')

    return tf_estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=metric_ops)
def model_fn(features, labels, mode, params):
  """Model function for the retriever + reader Estimator.

  Args:
    features: Input features, passed unchanged to `retrieve` and `read`.
    labels: Answer labels, or None. When None, a placeholder constant `[""]`
      is substituted so the downstream ops can still be constructed.
    mode: A `tf_estimator.ModeKeys` value.
    params: Dict of hyperparameters. Reads "reader_beam_size" always,
      "retriever_beam_size" outside PREDICT, and "learning_rate" /
      "num_train_steps" in TRAIN.

  Returns:
    A `tf_estimator.EstimatorSpec`. In PREDICT mode only `predictions` is set;
    otherwise `loss` and `eval_metric_ops` are set, plus `train_op` in TRAIN.

  Raises:
    ValueError: If the reader beam is larger than the retriever beam.
  """
  if labels is None:
    labels = tf.constant([""])

  reader_beam_size = params["reader_beam_size"]
  if mode == tf_estimator.ModeKeys.PREDICT:
    retriever_beam_size = reader_beam_size
  else:
    retriever_beam_size = params["retriever_beam_size"]
  # The reader consumes a prefix of the retrieved blocks (see the slicing
  # below), so it cannot ask for more blocks than were retrieved. A real
  # exception is used instead of `assert`, which is stripped under -O.
  if reader_beam_size > retriever_beam_size:
    raise ValueError(
        "reader_beam_size ({}) must not exceed retriever_beam_size ({})."
        .format(reader_beam_size, retriever_beam_size))

  with tf.device("/cpu:0"):
    retriever_outputs = retrieve(
        features=features,
        retriever_beam_size=retriever_beam_size,
        mode=mode,
        params=params)

  with tf.variable_scope("reader"):
    reader_outputs = read(
        features=features,
        retriever_logits=retriever_outputs.logits[:reader_beam_size],
        blocks=retriever_outputs.blocks[:reader_beam_size],
        mode=mode,
        params=params,
        labels=labels)

  predictions = get_predictions(reader_outputs, params)

  loss = None
  train_op = None
  eval_metric_ops = None
  if mode != tf_estimator.ModeKeys.PREDICT:
    # [retriever_beam_size]
    retriever_correct = orqa_ops.has_answer(
        blocks=retriever_outputs.blocks, answers=labels)

    # [reader_beam_size, num_candidates]
    reader_correct = compute_correct_candidates(
        candidate_starts=reader_outputs.candidate_starts,
        candidate_ends=reader_outputs.candidate_ends,
        gold_starts=reader_outputs.gold_starts,
        gold_ends=reader_outputs.gold_ends)

    eval_metric_ops = compute_eval_metrics(
        labels=labels,
        predictions=predictions,
        retriever_correct=retriever_correct,
        reader_correct=reader_correct)

    # []
    loss = compute_loss(
        retriever_logits=retriever_outputs.logits,
        retriever_correct=retriever_correct,
        reader_logits=reader_outputs.logits,
        reader_correct=reader_correct)

    # Estimator ignores `train_op` outside TRAIN, so only build the optimizer
    # (and presumably its gradient ops) when it will actually be run.
    if mode == tf_estimator.ModeKeys.TRAIN:
      num_train_steps = params["num_train_steps"]
      train_op = optimization.create_optimizer(
          loss=loss,
          init_lr=params["learning_rate"],
          num_train_steps=num_train_steps,
          # 10% of training, clamped to [100, 10000] steps.
          num_warmup_steps=min(10000, max(100, int(num_train_steps / 10))),
          use_tpu=False)

  return tf_estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      predictions=predictions,
      eval_metric_ops=eval_metric_ops)
def resnet_model_fn(features, labels, mode, params):
  """The model_fn for ResNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images, or a dict holding it under the key
      'feature'. If transpose_input is enabled, it is transposed to device
      layout and reshaped to 1D tensor.
    labels: `Tensor` of labels for the data samples.
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`.
    params: `dict` of parameters passed to the model from the TPUEstimator.

  Returns:
    A `tf_estimator.EstimatorSpec` in PREDICT mode, otherwise a
    `contrib_tpu.TPUEstimatorSpec`.

  Raises:
    ValueError: if `dropblock_groups` contains a group outside [1, 4], if
      `precision` is neither 'bfloat16' nor 'float32', if LARS is enabled, or
      if `FLAGS.optimizer` is not a supported optimizer.
  """
  if isinstance(features, dict):
    features = features['feature']

  # In most cases, the default data format NCHW instead of NHWC should be
  # used for a significant performance boost on GPU/TPU. NHWC should be used
  # only if the network needs to be run on CPU since the pooling operations
  # are only supported on NHWC.
  channels_first = params['data_format'] == 'channels_first'
  if channels_first:
    assert not params['transpose_input']  # channels_first only for GPU
    features = tf.transpose(features, [0, 3, 1, 2])

  if params['transpose_input'] and mode != tf_estimator.ModeKeys.PREDICT:
    # tf.sqrt produces a floating-point value, but tf.reshape requires an
    # integer shape, so cast the recovered (square) image size to int32.
    image_size = tf.cast(
        tf.sqrt(tf.shape(features)[0] / (3 * tf.shape(labels)[0])), tf.int32)
    features = tf.reshape(features, [image_size, image_size, 3, -1])
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

  # Normalize the image to zero mean and unit variance. The per-channel
  # constants must broadcast along the channel axis: the last axis for NHWC,
  # but axis 1 after the NCHW transpose above, hence [3, 1, 1] in that case.
  stats_shape = [3, 1, 1] if channels_first else [1, 1, 3]
  features -= tf.constant(MEAN_RGB, shape=stats_shape, dtype=features.dtype)
  features /= tf.constant(STDDEV_RGB, shape=stats_shape, dtype=features.dtype)

  # DropBlock keep_prob for the 4 block groups of ResNet architecture.
  # None means applying no DropBlock at the corresponding block group.
  dropblock_keep_probs = [None] * 4
  if params['dropblock_groups']:
    # Scheduled keep_prob for DropBlock: decays linearly from 1 towards
    # params['dropblock_keep_prob'] over params['train_steps'].
    train_steps = tf.cast(params['train_steps'], tf.float32)
    current_step = tf.cast(tf.train.get_global_step(), tf.float32)
    current_ratio = current_step / train_steps
    dropblock_keep_prob = (1 - current_ratio *
                           (1 - params['dropblock_keep_prob']))

    # Computes DropBlock keep_prob for different block groups of ResNet.
    dropblock_groups = [int(x) for x in params['dropblock_groups'].split(',')]
    for block_group in dropblock_groups:
      if block_group < 1 or block_group > 4:
        raise ValueError(
            'dropblock_groups should be a comma separated list of integers '
            'between 1 and 4 (dropblock_groups: {}).'.format(
                params['dropblock_groups']))
      # Earlier groups get a drop rate scaled down by 4 per group.
      dropblock_keep_probs[block_group - 1] = 1 - (
          (1 - dropblock_keep_prob) / 4.0**(4 - block_group))

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    network = resnet_model.resnet_v1(
        resnet_depth=params['resnet_depth'],
        num_classes=params['num_label_classes'],
        dropblock_size=params['dropblock_size'],
        dropblock_keep_probs=dropblock_keep_probs,
        data_format=params['data_format'])
    return network(
        inputs=features, is_training=(mode == tf_estimator.ModeKeys.TRAIN))

  if params['precision'] == 'bfloat16':
    with contrib_tpu.bfloat16_scope():
      logits = build_network()
    logits = tf.cast(logits, tf.float32)
  elif params['precision'] == 'float32':
    logits = build_network()
  else:
    # Previously fell through with `logits` unbound.
    raise ValueError(
        'precision must be "bfloat16" or "float32", got: {}'.format(
            params['precision']))

  if mode == tf_estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf_estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf_estimator.export.PredictOutput(predictions)
        })

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  one_hot_labels = tf.one_hot(labels, params['num_label_classes'])
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=one_hot_labels,
      label_smoothing=params['label_smoothing'])

  # Add weight decay to the loss for non-batch-normalization variables.
  loss = cross_entropy + params['weight_decay'] * tf.add_n([
      tf.nn.l2_loss(v)
      for v in tf.trainable_variables()
      if 'batch_normalization' not in v.name
  ])

  host_call = None
  if mode == tf_estimator.ModeKeys.TRAIN:
    # Compute the current epoch and associated learning rate from global_step.
    global_step = tf.train.get_global_step()
    steps_per_epoch = params['num_train_images'] / params['train_batch_size']
    current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
    # LARS (a large-batch optimizer) is deliberately unsupported here.
    if params['enable_lars']:
      raise ValueError('LARS unexpected in the context of IGT experiments.')

    learning_rate = linear_learning_rate_schedule(params, global_step)
    if FLAGS.optimizer == 'momentum':
      tf.logging.info('Using MomentumOptimizer ({}).'.format(
          params['momentum']))
      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate,
          momentum=params['momentum'],
          use_nesterov=False)
    elif FLAGS.optimizer == 'adam':
      tf.logging.info('Using AdamOptimizer')
      optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    elif FLAGS.optimizer == 'eigt':
      tf.logging.info('Using ExpIgtOptimizer {} tail: {}'.format(
          FLAGS.igt_optimizer, FLAGS.tail_fraction))
      optimizer = exp_igt_optimizer.ExpIgtOptimizer(
          learning_rate,
          tail_fraction=FLAGS.tail_fraction,
          optimizer=FLAGS.igt_optimizer)
    else:
      raise ValueError('{} is not a supported optimizer'.format(
          FLAGS.optimizer))

    if params['use_tpu']:
      # When using TPU, wrap the optimizer with CrossShardOptimizer which
      # handles synchronization details between different TPU cores. To the
      # user, this should look like regular synchronous training.
      optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)

    if not params['skip_host_call']:

      def host_call_fn(gs, loss, lr, ce):
        """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `metric_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step.
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
        gs = gs[0]
        # Host call fns are executed params['iterations_per_loop'] times after
        # one TPU loop is finished, setting max_queue value to the same as
        # number of iterations will make the summary writer only flush the data
        # to storage once per loop.
        with summary.create_file_writer(
            get_model_dir(params),
            max_queue=params['iterations_per_loop']).as_default():
          with summary.always_record_summaries():
            summary.scalar('loss', loss[0], step=gs)
            summary.scalar('learning_rate', lr[0], step=gs)
            summary.scalar('current_epoch', ce[0], step=gs)

            return summary.all_summary_ops()

      # To log the loss, current learning rate, and epoch for Tensorboard, the
      # summary op needs to be run on the host CPU via host_call. host_call
      # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
      # dimension. These Tensors are implicitly concatenated to
      # [params['batch_size']].
      gs_t = tf.reshape(global_step, [1])
      loss_t = tf.reshape(loss, [1])
      lr_t = tf.reshape(learning_rate, [1])
      ce_t = tf.reshape(current_epoch, [1])

      host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

  else:
    train_op = None

  eval_metrics = None
  scaffold_fn = None
  if mode == tf_estimator.ModeKeys.EVAL:

    def metric_fn(labels, logits):
      """Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
      predictions = tf.argmax(logits, axis=1)
      top_1_accuracy = tf.metrics.accuracy(labels, predictions)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      top_5_accuracy = tf.metrics.mean(in_top_5)

      return {
          'top_1_accuracy': top_1_accuracy,
          'top_5_accuracy': top_5_accuracy,
      }

    eval_metrics = (metric_fn, [labels, logits])

    if FLAGS.mode == 'eval_igt' and FLAGS.igt_eval_mode == 'true':
      tf.logging.info('Using true param loading saver.')

      def scaffold_fn_true_params():
        """Returns a scaffold that loads the true values into vars."""
        # Trainable variables are restored from their '<name>/true_param'
        # checkpoint entries; all other variables from their own names.
        var_mapping = {}
        trainable_vars = set(tf.trainable_variables())
        for var in tf.global_variables():
          if var in trainable_vars:
            var_mapping[var.op.name + '/true_param'] = var
          else:
            var_mapping[var.op.name] = var

        tf.logging.info('Mapping: {}'.format(var_mapping))
        saver = tf.train.Saver(var_list=var_mapping, sharded=True)
        return tf.train.Scaffold(saver=saver)

      scaffold_fn = scaffold_fn_true_params

  return contrib_tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics,
      scaffold_fn=scaffold_fn)
def model_fn(features, labels, mode, params, grammar):
  """Builds the model graph.

  Args:
    features: Dict of tensors. Keys read here include 'expression_string',
      'partial_sequence', 'partial_sequence_length' and
      'next_production_rule_mask'; the whole dict is also passed to
      networks.partial_sequence_encoder.
    labels: Tensor of target next-production-rule ids; it is used directly as
      the sparse labels of the cross-entropy loss and in the accuracy
      metrics. Not used in PREDICT mode.
    mode: tf.estimator.ModeKeys execution mode.
    params: HParams object containing model hyperparameters.
    grammar: arithmetic_grammar.Grammar object.

  Returns:
    A tf_estimator.EstimatorSpec: predictions only in PREDICT mode, loss and
    train_op in TRAIN mode, loss and eval_metric_ops in EVAL mode.
  """
  if mode != tf_estimator.ModeKeys.PREDICT:
    tf.summary.text('expression_string', features['expression_string'][:10])
  tf.summary.text('production_rules', tf.constant(grammar.grammar_to_string()))

  # Make features easier to look up.
  with tf.variable_scope('features'):
    features = {
        key: tf.identity(value, name=key)
        for key, value in six.iteritems(features)
    }

  embedding_layer = networks.partial_sequence_encoder(
      features=features,
      symbolic_properties=core.hparams_list_value(params.symbolic_properties),
      numerical_points=core.hparams_list_value(params.numerical_points),
      num_production_rules=grammar.num_production_rules,
      embedding_size=params.embedding_size)

  logits = networks.build_stacked_gru_model(
      embedding_layer=embedding_layer,
      partial_sequence_length=features['partial_sequence_length'],
      gru_hidden_sizes=params.gru_hidden_sizes,
      num_output_features=grammar.num_production_rules,
      bidirectional=params.bidirectional)

  predictions = {'logits': tf.identity(logits, name='predictions/logits')}
  predictions.update({
      name: tf.identity(tensor, name='predictions/%s' % name)
      for name, tensor in six.iteritems(
          mask_logits(logits, features['next_production_rule_mask']))
  })
  predictions['next_production_rule'] = tf.argmax(
      predictions['masked_probabilities'],
      axis=1,
      name='predictions/next_production_rule')

  if mode == tf_estimator.ModeKeys.PREDICT:
    return tf_estimator.EstimatorSpec(mode=mode, predictions=predictions)

  # NOTE(leeley): The mask cannot be applied directly on logits, because a 0
  # logit still corresponds to a positive probability. Since
  # tf.losses.sparse_softmax_cross_entropy() only works for logits rather than
  # probabilities, I convert probabilities back to logits by tf.log(). Since
  # the probabilities for grammatically invalid production rules are 0, to
  # avoid the numerical issue of log(0), I added a small number 1e-10.
  loss = tf.losses.sparse_softmax_cross_entropy(
      labels, tf.log(predictions['masked_probabilities'] + 1e-10))

  # Configure the training op for TRAIN mode.
  if mode == tf_estimator.ModeKeys.TRAIN:
    train_op = contrib_layers.optimize_loss(
        loss=loss,
        global_step=tf.train.get_global_step(),
        learning_rate=core.learning_rate_decay(
            initial_learning_rate=params.learning_rate,
            decay_steps=params.learning_rate_decay_steps,
            decay_rate=params.learning_rate_decay_rate),
        optimizer=params.optimizer,
        summaries=contrib_layers.OPTIMIZER_SUMMARIES)
    return tf_estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

  # Add evaluation metrics for EVAL mode.
  eval_metric_ops = {
      'eval_loss': tf.metrics.mean(loss),
      'count': contrib_metrics.count(labels),
      'next_production_rule_valid_ratio': (
          metrics.next_production_rule_valid_ratio(
              unmasked_probabilities_batch=predictions[
                  'unmasked_probabilities'],
              next_production_rule_masks=features[
                  'next_production_rule_mask'])),
      'next_production_rule_accuracy': metrics.next_production_rule_accuracy(
          next_production_rules=labels,
          predict_next_production_rules=predictions['next_production_rule']),
  }

  # Per-target-length breakdowns of the same metrics.
  for target_length in range(1, params.max_length + 1):
    eval_metric_ops[
        'next_production_rule_info/length_%d' % target_length] = (
            metrics.next_production_rule_info_batch_text_summary(
                expression_strings=features['expression_string'],
                partial_sequences=features['partial_sequence'],
                partial_sequence_lengths=features['partial_sequence_length'],
                next_production_rules=labels,
                unmasked_probabilities_batch=predictions[
                    'unmasked_probabilities'],
                masked_probabilities_batch=predictions[
                    'masked_probabilities'],
                grammar=grammar,
                target_length=target_length))
    eval_metric_ops[
        'next_production_rule_valid_ratio/length_%d' % target_length] = (
            metrics.next_production_rule_valid_ratio(
                unmasked_probabilities_batch=predictions[
                    'unmasked_probabilities'],
                next_production_rule_masks=features[
                    'next_production_rule_mask'],
                partial_sequence_lengths=features['partial_sequence_length'],
                target_length=target_length))
    eval_metric_ops[
        'next_production_rule_accuracy/length_%d' % target_length] = (
            metrics.next_production_rule_accuracy(
                next_production_rules=labels,
                predict_next_production_rules=predictions[
                    'next_production_rule'],
                partial_sequence_lengths=features['partial_sequence_length'],
                target_length=target_length))

  if params.num_expressions_per_condition > 0:
    # NOTE(review): these placeholders must be fed during evaluation
    # (presumably by a hook that runs conditional generation) — confirm.
    with tf.variable_scope('conditional_generation'):
      match_ratio = tf.placeholder(tf.float32, shape=[None], name='match_ratio')
      fail_ratio = tf.placeholder(tf.float32, shape=[None], name='fail_ratio')

    eval_metric_ops.update({
        'generation_match_ratio': tf.metrics.mean(match_ratio),
        'generation_fail_ratio': tf.metrics.mean(fail_ratio),
    })

  return tf_estimator.EstimatorSpec(
      mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
def resnet_model_fn(features, labels, mode, params):
  """Builds a ResNet-50 EstimatorSpec for training or evaluation.

  Images are normalized with params['mean_rgb'] / params['stddev_rgb'], fed
  through resnet_model.resnet_50, and scored with label-smoothed (0.1)
  softmax cross-entropy plus L2 weight decay on every trainable variable
  whose name does not contain 'batch_normalization'.

  Args:
    features: A float32 batch of images, or a dict holding it under
      'feature'.
    labels: An int32 batch of class labels.
    mode: A tf_estimator.ModeKeys value. TRAIN builds a Nesterov momentum
      train_op and writes v2 scalar summaries to params['output_dir']; EVAL
      builds top-1 / top-5 accuracy metrics.
    params: Dictionary of hyperparameters.

  Returns:
    A tf_estimator.EstimatorSpec.
  """
  images = features['feature'] if isinstance(features, dict) else features

  rgb_mean, rgb_std = params['mean_rgb'], params['stddev_rgb']
  images -= tf.constant(rgb_mean, shape=[1, 1, 3], dtype=images.dtype)
  images /= tf.constant(rgb_std, shape=[1, 1, 3], dtype=images.dtype)

  batch_size = params['train_batch_size']
  epoch_steps = params['num_train_images'] / batch_size
  base_lr = params['base_learning_rate']
  num_classes = params['num_label_classes']
  is_training = mode == tf_estimator.ModeKeys.TRAIN

  resnet = resnet_model.resnet_50(
      num_classes=num_classes, data_format=params['data_format'])
  logits = resnet(inputs=images, is_training=is_training)

  summary_dir = params['output_dir']
  decay = params['weight_decay']

  onehot = tf.one_hot(labels, num_classes)
  xent = tf.losses.softmax_cross_entropy(
      logits=logits, onehot_labels=onehot, label_smoothing=0.1)
  # Batch-norm parameters are excluded from weight decay.
  l2_terms = [
      tf.nn.l2_loss(var)
      for var in tf.trainable_variables()
      if 'batch_normalization' not in var.name
  ]
  loss = xent + decay * tf.add_n(l2_terms)

  train_op = None
  if is_training:
    step = tf.train.get_global_step()
    epoch = tf.cast(step, tf.float32) / epoch_steps
    lr = compute_lr(epoch, base_lr, batch_size, params['lr_schedule'])
    momentum_opt = tf.train.MomentumOptimizer(
        learning_rate=lr, momentum=params['momentum'], use_nesterov=True)
    # Run batch-norm moving-average updates together with each step.
    bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(bn_updates), tf.name_scope('train'):
      train_op = momentum_opt.minimize(loss, step)

    scalar_summaries = (
        ('loss', loss),
        ('learning_rate', lr),
        ('current_epoch', epoch),
        ('steps_per_epoch', epoch_steps),
        ('weight_decay', decay),
    )
    with tf2.summary.create_file_writer(summary_dir).as_default():
      with tf2.summary.record_if(True):
        for tag, value in scalar_summaries:
          tf2.summary.scalar(tag, value, step=step)
    tf.summary.all_v2_summary_ops()

  eval_metrics = {}
  if mode == tf_estimator.ModeKeys.EVAL:
    predicted = tf.argmax(logits, axis=1)
    eval_metrics['top_1_accuracy'] = tf.metrics.accuracy(labels, predicted)
    hits_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
    eval_metrics['top_5_accuracy'] = tf.metrics.mean(hits_top_5)

  return tf_estimator.EstimatorSpec(
      training_hooks=None,
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=eval_metrics)
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None,
                       use_tpu=False):
  """Builds the Estimator spec for a Mesh-TensorFlow model class `cls`.

  Args:
    cls: Model class. It is instantiated as
      `cls(hparams, mode, data_parallelism=..., decode_hparams=...)` and its
      instance must provide `mtf_model_fn`, `estimator_spec_predict` and
      `estimator_spec_eval`.
    hparams: Model hyperparameters. `hparams_lib.copy_hparams` is applied
      first, so presumably the caller's object is not mutated.
    features: Input features, passed to the model.
    labels: Labels, passed to `model.estimator_spec_eval` in EVAL mode.
    mode: A `tf_estimator.ModeKeys` value.
    config: Optional run config. Off-TPU, its `data_parallelism` is used to
      pick the mesh devices (parameter-server devices).
    params: Estimator params. On TPU, `params["context"]` must provide the
      TPU context (hosts, replicas, placement function, device assignment).
    decode_hparams: Optional decoding hparams. In PREDICT mode each of its
      values is set on `hparams`, logging a warning when it overrides an
      existing, different value.
    use_tpu: Whether to build a SIMD (TPU) mesh and a `TPUEstimatorSpec`.

  Returns:
    PREDICT: the result of `model.estimator_spec_predict`.
    EVAL: the result of `model.estimator_spec_eval`.
    TRAIN: a `tpu_estimator.TPUEstimatorSpec` if `use_tpu`, otherwise a
      `tf_estimator.EstimatorSpec`, each carrying the Mtf restore hook and a
      checkpoint-saver hook.
  """
  hparams = hparams_lib.copy_hparams(hparams)
  hparams.use_tpu = use_tpu
  # merge decode_hparams into hparams if present
  if mode == tf_estimator.ModeKeys.PREDICT and decode_hparams is not None:
    for k, v in six.iteritems(decode_hparams.values()):
      if hasattr(hparams, k) and getattr(hparams, k) != v:
        tf.logging.warning(
            "Overriding hparams.%s with %s from decode_hparams" % (k, v))
      setattr(hparams, k, v)

  # Instantiate model
  data_parallelism = None
  if not use_tpu and config:
    data_parallelism = config.data_parallelism
  model = cls(hparams,
              mode,
              data_parallelism=data_parallelism,
              decode_hparams=decode_hparams)

  global_step = tf.train.get_global_step()

  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(hparams.layout)
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [
        host_placement_fn(host_id=t) for t in range(num_hosts)
    ]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    # Initial per-host memory estimate handed to the variable placer:
    # worker 0 starts with the binary cache, all other hosts with 0.
    # NOTE(review): "memeory" is a typo in the local name; harmless.
    devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(
        device_list, devices_memeory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    # With zero or one parameter server every mesh slot uses the default
    # device; otherwise there must be exactly one ps device per mesh slot.
    if data_parallelism is None or len(
        data_parallelism.ps_devices) == 1:
      mesh_devices = [""] * mesh_shape.size
    else:
      assert len(data_parallelism.ps_devices) == mesh_shape.size
      mesh_devices = data_parallelism.ps_devices
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)
  # PREDICT mode
  if mode == tf_estimator.ModeKeys.PREDICT:
    return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

  logits, loss = model.mtf_model_fn(features, mesh)
  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf_estimator.ModeKeys.TRAIN:
    # Gradients and optimizer updates are built in the Mtf graph, before
    # lowering, so that they are lowered together with the forward pass.
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    lr = learning_rate.learning_rate_schedule(hparams)
    tf.summary.scalar("learning_rate", lr)
    mtf_lr = mtf.import_tf_tensor(
        mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
    optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  # NOTE(review): relies on the truthiness of an Mtf tensor; `logits is not
  # None` would state the intent explicitly. The EVAL branch below exports
  # logits again, so this export appears redundant there — confirm.
  if logits and mode != tf_estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)

  if mode == tf_estimator.ModeKeys.TRAIN:
    tf_update_ops = [
        lowering.lowered_operation(op) for op in update_ops
    ]
    # The global step is advanced explicitly as part of train_op.
    tf_update_ops.append(tf.assign_add(global_step, 1))
    # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    # Checkpoint every 1000 steps, keeping the 10 most recent (plus one
    # every 2 hours).
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        hparams.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

  # EVAL mode
  if mode == tf_estimator.ModeKeys.EVAL:
    tf_logits = lowering.export_to_tf_tensor(logits)
    return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                     restore_hook, use_tpu)

  # Only TRAIN reaches this point.
  if use_tpu:
    # TPU host call. Important: need to be called before remove_summaries()
    if hparams.tpu_enable_host_call:
      host_call = t2t_model.create_host_call(hparams.model_dir)
    else:
      host_call = None

    # On TPU, warm-starting is deferred into scaffold_fn.
    if hparams.warm_start_from:

      def scaffold_fn():
        t2t_model.initialize_from_ckpt(
            ckpt_dir=hparams.warm_start_from, hparams=hparams)
        return tf.train.Scaffold()
    else:
      scaffold_fn = None

    t2t_model.remove_summaries()
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf_estimator.ModeKeys.TRAIN,
        loss=tf_loss,
        train_op=train_op,
        host_call=host_call,
        training_hooks=[restore_hook, saver_hook],
        scaffold_fn=scaffold_fn)
  else:
    # Off TPU, warm-starting happens directly while building the graph.
    if hparams.warm_start_from:
      t2t_model.initialize_from_ckpt(
          ckpt_dir=hparams.warm_start_from, hparams=hparams)
    return tf_estimator.EstimatorSpec(
        tf_estimator.ModeKeys.TRAIN,
        loss=tf_loss,
        train_op=train_op,
        training_chief_hooks=[restore_hook, saver_hook])
def model_fn(features, labels, mode):
  """BaselineModel model_fn with inverse propensity score (IPS) reweighting.

  Args:
    features: `Tensor` or `dict` of `Tensor`.
    labels: A `dict` of `Tensor` objects. Must contain the keys
      self._label_column_name, IPS_WITH_LABEL_TARGET_COLUMN_NAME and
      IPS_WITHOUT_LABEL_TARGET_COLUMN_NAME. IPS stands for inverse propensity
      score: each example is assigned a weight inversely proportional to its
      propensity of appearing in the training distribution, i.e.
      ips-weight = 1/p(x), where p(x) is the probability of x in the training
      distribution. With "IPS_without_label", each example is weighted by the
      inverse propensity score of its subgroup, e.g. 1/p("Black Female").
      With "IPS_with_label", each example is weighted by the inverse
      propensity score of its subgroup and class membership, e.g.
      1/p("Black Female", "class 0").
    mode: Defines whether this is training, evaluation or prediction. See
      `ModeKeys`. Only TRAIN and EVAL are implemented.

  Returns:
    An instance of `tf.estimator.EstimatorSpec` carrying `mode`,
    `predictions` and `loss`, plus `train_op` in TRAIN mode or
    `eval_metric_ops` in EVAL mode. `predictions` is a `dict` of `Tensor`
    objects holding the binary classifier's class ids and sigmoid outputs.

  Raises:
    NotImplementedError: If `mode` is neither TRAIN nor EVAL (previously this
      fell through to an unbound `estimator_spec`).
    ValueError: If `self._reweighting_type` is neither 'IPS_with_label' nor
      'IPS_without_label' (previously left `example_weights` unbound).
  """
  if mode not in (tf_estimator.ModeKeys.TRAIN, tf_estimator.ModeKeys.EVAL):
    raise NotImplementedError(
        'model_fn only supports TRAIN and EVAL modes, got: {}'.format(mode))

  # Instantiates a tensor with true class labels
  class_labels = labels[self._label_column_name]
  # Both IPS weight columns are required; the reweighting type picks one.
  weights_by_reweighting_type = {
      'IPS_with_label': labels[IPS_WITH_LABEL_TARGET_COLUMN_NAME],
      'IPS_without_label': labels[IPS_WITHOUT_LABEL_TARGET_COLUMN_NAME],
  }
  if self._reweighting_type not in weights_by_reweighting_type:
    raise ValueError(
        'Unsupported reweighting_type: {!r}; expected one of {}.'.format(
            self._reweighting_type, sorted(weights_by_reweighting_type)))
  example_weights = weights_by_reweighting_type[self._reweighting_type]

  tf.logging.info('model_fn for mode: {}'.format(mode))

  with tf.name_scope('model'):
    input_layer = tf.feature_column.input_layer(
        features, self._feature_columns)
    layer = input_layer
    for unit in self._hidden_units:
      layer = tf.layers.Dense(unit, activation=self._activation)(layer)
    logits = tf.layers.Dense(1)(layer)
    sigmoid_output = tf.nn.sigmoid(logits, name='sigmoid')
    # Binary decision at a 0.5 probability threshold.
    class_predictions = tf.cast(tf.greater(sigmoid_output, 0.5), tf.float32)
    tf.summary.histogram('class_predictions', class_predictions)

  # Initializes Loss Functions
  loss = self._loss(class_labels, logits, example_weights)

  # Sets up dictionaries used for computing performance metrics
  predictions = {
      (self._label_column_name, 'class_ids'):
          tf.reshape(class_predictions, [-1]),
      (self._label_column_name, 'logistic'):
          tf.reshape(sigmoid_output, [-1])
  }

  class_id_kwargs = {
      'labels': class_labels,
      'predictions': class_predictions
  }
  logistics_kwargs = {
      'labels': class_labels,
      'predictions': sigmoid_output
  }

  # EVAL Mode
  if mode == tf_estimator.ModeKeys.EVAL:
    with tf.name_scope('eval_metrics'):
      eval_metric_ops = {
          'accuracy': tf.metrics.accuracy(**class_id_kwargs),
          'precision': tf.metrics.precision(**class_id_kwargs),
          'recall': tf.metrics.recall(**class_id_kwargs),
          'fp': tf.metrics.false_positives(**class_id_kwargs),
          'fn': tf.metrics.false_negatives(**class_id_kwargs),
          'tp': tf.metrics.true_positives(**class_id_kwargs),
          'tn': tf.metrics.true_negatives(**class_id_kwargs),
          'fpr': contrib_metrics.streaming_false_positive_rate(
              **class_id_kwargs),
          'fnr': contrib_metrics.streaming_false_negative_rate(
              **class_id_kwargs),
          'auc': tf.metrics.auc(curve='ROC', **logistics_kwargs),
          'aucpr': tf.metrics.auc(curve='PR', **logistics_kwargs)
      }

    # EstimatorSpec object for evaluation
    return tf_estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        eval_metric_ops=eval_metric_ops)

  # TRAIN Mode
  train_op_primary = contrib_layers.optimize_loss(
      loss=loss,
      learning_rate=self._learning_rate,
      global_step=contrib_framework.get_global_step(),
      optimizer=self._optimizer)

  return tf_estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op_primary)
def model_fn(features, labels, mode):
  """BaselineModel model_fn.

  Args:
    features: `Tensor` or `dict` of `Tensor`.
    labels: A `dict` of `Tensor` objects. Must contain the key
      self._label_column_name.
    mode: Defines whether this is training, evaluation or prediction. See
      `ModeKeys`. Only TRAIN and EVAL are implemented.

  Returns:
    An instance of `tf.estimator.EstimatorSpec` carrying `mode`,
    `predictions` and `loss`, plus `train_op` in TRAIN mode or
    `eval_metric_ops` in EVAL mode. `predictions` is a `dict` of `Tensor`
    objects holding the binary classifier's class ids and sigmoid outputs.

  Raises:
    NotImplementedError: If `mode` is neither TRAIN nor EVAL (previously this
      fell through to an unbound `estimator_spec`).
  """
  if mode not in (tf_estimator.ModeKeys.TRAIN, tf_estimator.ModeKeys.EVAL):
    raise NotImplementedError(
        'model_fn only supports TRAIN and EVAL modes, got: {}'.format(mode))

  # Instantiates a tensor with true class labels
  class_labels = labels[self._label_column_name]

  tf.logging.info('model_fn for mode: {}'.format(mode))

  with tf.name_scope('model'):
    input_layer = tf.feature_column.input_layer(features,
                                                self._feature_columns)
    layer = input_layer
    for unit in self._hidden_units:
      layer = tf.layers.Dense(unit, activation=self._activation)(layer)
    logits = tf.layers.Dense(1)(layer)
    sigmoid_output = tf.nn.sigmoid(logits, name='sigmoid')
    # Binary decision at a 0.5 probability threshold.
    class_predictions = tf.cast(tf.greater(sigmoid_output, 0.5), tf.float32)
    tf.summary.histogram('class_predictions', class_predictions)

  # Initializes Loss Functions
  loss = self._loss(class_labels, logits)

  # Sets up dictionaries used for computing performance metrics
  predictions = {
      (self._label_column_name, 'class_ids'):
          tf.reshape(class_predictions, [-1]),
      (self._label_column_name, 'logistic'):
          tf.reshape(sigmoid_output, [-1])
  }

  class_id_kwargs = {
      'labels': class_labels,
      'predictions': class_predictions
  }
  logistics_kwargs = {'labels': class_labels, 'predictions': sigmoid_output}

  # EVAL Mode
  if mode == tf_estimator.ModeKeys.EVAL:
    with tf.name_scope('eval_metrics'):
      eval_metric_ops = {
          'accuracy': tf.metrics.accuracy(**class_id_kwargs),
          'precision': tf.metrics.precision(**class_id_kwargs),
          'recall': tf.metrics.recall(**class_id_kwargs),
          'fp': tf.metrics.false_positives(**class_id_kwargs),
          'fn': tf.metrics.false_negatives(**class_id_kwargs),
          'tp': tf.metrics.true_positives(**class_id_kwargs),
          'tn': tf.metrics.true_negatives(**class_id_kwargs),
          'fpr': contrib_metrics.streaming_false_positive_rate(
              **class_id_kwargs),
          'fnr': contrib_metrics.streaming_false_negative_rate(
              **class_id_kwargs),
          'auc': tf.metrics.auc(curve='ROC', **logistics_kwargs),
          'aucpr': tf.metrics.auc(curve='PR', **logistics_kwargs)
      }

    # EstimatorSpec object for evaluation
    return tf_estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        eval_metric_ops=eval_metric_ops)

  # TRAIN Mode
  train_op_primary = contrib_layers.optimize_loss(
      loss=loss,
      learning_rate=self._learning_rate,
      global_step=contrib_framework.get_global_step(),
      optimizer=self._optimizer)

  return tf_estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op_primary)
def model_function(features, labels, mode, params, embeddings):
  """A model function satisfying the tf.estimator API.

  Args:
    features: Dictionary of feature tensors with keys:
        - question_tok: <string> [batch_size, max_question_len]
        - context_tok: <string> [batch_size, max_num_context, max_context_len]
        - question_tok_len: <int32> [batch_size]
        - num_context: <int32> [batch_size]
        - context_tok_len: <int32> [batch_size]
        - question_tok_wid: <int32> [batch_size, max_question_len]
        - context_tok_wid: <int32> [batch_size, max_num_context,
          max_context_len]
        - long_answer_indices: <int32> [batch_size]
      In PREDICT mode the features are unbatched; a dummy batch dimension is
      added here and removed again from the predictions.
    labels: <int32> long-answer indices (-1 = NULL). Ignored in PREDICT mode.
      The code below takes `labels[:, 0]`, so it must be at least rank 2
      (presumably [batch_size, num_annotations] -- TODO confirm).
    mode: One of the keys from tf.estimator.ModeKeys.
    params: Dictionary of hyperparameters (unused).
    embeddings: An embedding_utils.PretrainedWordEmbeddings object.

  Returns:
    estimator_spec: A tf.estimator.EstimatorSpec object. `train_op` is only
    built in TRAIN mode; `loss` and `eval_metric_ops` are built in TRAIN and
    EVAL modes.
  """
  del params  # Unused.

  if mode == tf_estimator.ModeKeys.PREDICT:
    # Add a dummy batch dimension if we are exporting the predictor.
    features = {k: tf.expand_dims(v, 0) for k, v in features.items()}

  embedding_weights, embedding_scaffold = embeddings.get_params(
      trainable=False)

  # Features.
  question_tok_len = features["question_tok_len"]
  question_tok_wid = features["question_tok_wid"]
  context_tok_wid = features["context_tok_wid"]
  num_context = features["num_context"]
  context_tok_len = features["context_tok_len"]

  # Truncate the contexts and labels to a certain maximum length.
  context_tok_wid, num_context, context_tok_len = (
      nq_long_utils.truncate_contexts(
          context_token_ids=context_tok_wid,
          num_contexts=num_context,
          context_len=context_tok_len,
          max_contexts=FLAGS.max_contexts,
          max_context_len=FLAGS.max_context_len))

  non_null_context_scores = nq_long_decatt_model.build_model(
      question_tok_wid=question_tok_wid,
      question_lens=question_tok_len,
      context_tok_wid=context_tok_wid,
      context_lens=context_tok_len,
      embedding_weights=embedding_weights,
      mode=mode)

  # Mask out contexts that are padding: log(1) = 0 leaves real contexts
  # unchanged, log(0) = -inf removes padded ones from the softmax / argmax.
  num_context_mask = tf.log(
      tf.sequence_mask(
          num_context,
          tensor_utils.shape(non_null_context_scores, 1),
          dtype=tf.float32))
  non_null_context_scores += num_context_mask

  # <float> [batch_size, 1]
  null_score = tf.zeros([tf.shape(question_tok_wid)[0], 1])

  # Offset everything by 1 to account for null context.
  # [batch_size, 1 + max_contexts]
  context_scores = tf.concat([null_score, non_null_context_scores], 1)

  if mode != tf_estimator.ModeKeys.PREDICT:
    labels = nq_long_utils.truncate_labels(labels, FLAGS.max_contexts)

    # In the data, NULL is given index -1 but this is not compatible with
    # softmax so shift by 1.
    labels = labels + 1

    # Reweight null examples.
    weights = nq_long_utils.compute_null_weights(labels, FLAGS.null_weight)

    # When computing the loss we take only the first label.
    loss_labels = labels[:, 0]

    # []
    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=loss_labels,
        logits=context_scores,
        weights=weights)

    if mode == tf_estimator.ModeKeys.TRAIN:
      optimizer = tf.train.AdagradOptimizer(
          learning_rate=FLAGS.learning_rate)
      train_op = optimizer.minimize(
          loss=loss, global_step=tf.train.get_global_step())
    else:
      # The Estimator only runs train_op in TRAIN mode; skip building it so
      # the eval graph does not create unused optimizer slot variables.
      train_op = None

    # <int32> [batch_size]
    eval_predictions = tf.to_int32(tf.argmax(context_scores, 1))

    non_null_match, non_null_gold, non_null_predictions = (
        nq_long_utils.compute_match_stats(eval_predictions, labels))

    precision, precision_op = (
        tf.metrics.mean(non_null_match, weights=non_null_predictions))
    recall, recall_op = (
        tf.metrics.mean(non_null_match, weights=non_null_gold))

    f1, f1_op = (
        nq_long_utils.f1_metric(
            precision=precision,
            precision_op=precision_op,
            recall=recall,
            recall_op=recall_op))

    # Bogus metric until we figure out how to connect Ming Wei's eval code.
    eval_metric_ops = {
        "precision": (precision, precision_op),
        "recall": (recall, recall_op),
        "f1": (f1, f1_op)
    }
  else:
    loss = None
    train_op = None
    eval_metric_ops = {}

  # In the export, we never predict NULL since the eval metric will compute the
  # best possible F1.
  export_long_answer_idx = tf.to_int32(tf.argmax(non_null_context_scores, 1))
  export_long_answer_score = tf.reduce_max(non_null_context_scores, 1)
  predictions = dict(idx=export_long_answer_idx, score=export_long_answer_score)

  if mode == tf_estimator.ModeKeys.PREDICT:
    # Remove the dummy batch dimension if we are exporting the predictor.
    predictions = {k: tf.squeeze(v, 0) for k, v in predictions.items()}

  estimator_spec = tf_estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      predictions=predictions,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops,
      scaffold=embedding_scaffold)

  return estimator_spec
def _model_fn(features, labels, mode, params):
  """Builds the EstimatorSpec for the image denoising model.

  The raw noisy input is denoised by `inference_fn`. The noisy input, the
  denoised output and the ground truth are then each passed through
  `process.process` with the per-example metadata in `features`. Loss and
  metrics are computed on the processed images.

  Args:
    features: Dictionary of input tensors. Reads 'noisy_img' and 'variance'
      (inputs to `inference_fn`) and 'red_gain', 'blue_gain', 'cam2rgb'
      (metadata for `process.process`).
    labels: Ground-truth image tensor when mode is `TRAIN` or `EVAL`. It is
      processed unconditionally below.
    mode: ModeKey object (`TRAIN` or `EVAL`).
    params: Ignored.

  Returns:
    An EstimatorSpec with an L1 loss in TRAIN/EVAL, an Adam train_op in TRAIN,
    and PSNR plus image-summary "metrics" in EVAL.
  """
  del params  # Unused.

  def finish(raw):
    """Processes `raw` images using this batch's fixed metadata."""
    return process.process(raw, features['red_gain'], features['blue_gain'],
                           features['cam2rgb'])

  raw_denoised = inference_fn(features['noisy_img'], features['variance'])
  noisy = finish(features['noisy_img'])
  denoised = finish(raw_denoised)
  truth = finish(labels)

  is_train = mode == tf_estimator.ModeKeys.TRAIN
  is_eval = mode == tf_estimator.ModeKeys.EVAL

  # L1 distance between the processed ground truth and denoised output.
  loss = (tf.losses.absolute_difference(truth, denoised)
          if is_train or is_eval else None)

  train_op = None
  if is_train:
    adam = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
    train_op = contrib_layers.optimize_loss(
        loss=loss,
        global_step=tf.train.get_global_step(),
        learning_rate=None,
        optimizer=adam,
        name='')  # Prevents scope prefix.

  eval_metric_ops = None
  if is_eval:
    eval_metric_ops = {'PSNR': psnr(truth, denoised)}

    def add_image_summary(pixels, tag):
      """Hack: registers an image summary as an entry of `eval_metric_ops`."""
      as_bytes = tf.saturate_cast(pixels * 255 + 0.5, tf.uint8)
      eval_metric_ops[tag] = (tf.summary.image(tag, as_bytes, max_outputs=2),
                              tf.no_op())

    for tag, pixels in (('Noisy', noisy), ('Denoised', denoised),
                        ('Truth', truth)):
      add_image_summary(pixels, tag)
    # Shift and scale the signed difference so it can be shown as an image.
    add_image_summary((denoised - truth + 1.0) / 2.0, 'Diffs')

  return tf_estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops)
def model_fn_w_pruning(features, labels, mode, params):
  """The model_fn for ResNet-50 with pruning.

  Args:
    features: A float32 batch of images, or (for the "pie_dataset_gen",
      "imagenet_training" and "imagenet_predictions" tasks) a dict holding the
      images under "image_raw" and the labels under "label". For the
      "pie_dataset_gen", "robustness_imagenet_c", "robustness_imagenet_a" and
      "ckpt_prediction" tasks it must also hold "human_label".
    labels: A int32 batch of labels. Replaced by features["label"] for the
      tasks that carry labels inside `features`.
    mode: Specifies whether training, evaluation or prediction.
    params: Dictionary of model/task settings. Reads "task", "mean_rgb",
      "stddev_rgb", "num_label_classes", "pruning_method", "label_smoothing",
      "weight_decay", and whatever `train_function` reads in TRAIN mode.

  Returns:
    An EstimatorSpec for the model.
  """
  task = params["task"]

  if task in ("pie_dataset_gen", "imagenet_training", "imagenet_predictions"):
    pixels = features["image_raw"]
    labels = features["label"]
  else:
    pixels = features

  # NOTE(review): human_labels is only bound for these tasks; EVAL mode with
  # any other task raises NameError below.
  if task in ("pie_dataset_gen", "robustness_imagenet_c",
              "robustness_imagenet_a", "ckpt_prediction"):
    human_labels = features["human_label"]

  channel_means = params["mean_rgb"]
  channel_stddevs = params["stddev_rgb"]

  # Per-channel normalization to zero mean and unit variance.
  pixels = pixels - tf.constant(
      channel_means, shape=[1, 1, 3], dtype=pixels.dtype)
  pixels = pixels / tf.constant(
      channel_stddevs, shape=[1, 1, 3], dtype=pixels.dtype)

  num_classes = params["num_label_classes"]
  resnet = resnet_model.resnet_50(
      num_classes=num_classes,
      pruning_method=params["pruning_method"],
      data_format="channels_last")
  logits = resnet(
      inputs=pixels, is_training=(mode == tf_estimator.ModeKeys.TRAIN))

  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=tf.one_hot(labels, num_classes),
      label_smoothing=params["label_smoothing"])

  # L2 weight decay over every trainable variable except batch-norm ones.
  decayed = [
      tf.nn.l2_loss(var)
      for var in tf.trainable_variables()
      if "batch_normalization" not in var.name
  ]
  loss = cross_entropy + params["weight_decay"] * tf.add_n(decayed)

  if mode == tf_estimator.ModeKeys.PREDICT:
    # we run predictions on gpu since ordering is very important and
    # thus we need to run with batch size 1 (not enabled on tpu)
    train_op = None
    eval_metrics = None
    probabilities = tf.nn.softmax(logits, name="softmax")
    predicted_probability = tf.cast(
        tf.reduce_max(probabilities, axis=1), tf.float32)
    _, top_5_indices = tf.nn.top_k(tf.to_float(logits), k=5)
    predictions = {
        "predictions": tf.argmax(logits, axis=1),
        "true_class": labels,
        "predicted_probability": predicted_probability,
        "top_5_indices": top_5_indices
    }
  elif mode == tf_estimator.ModeKeys.TRAIN:
    train_op = train_function(params, loss)
    eval_metrics = None
    predictions = None
  elif mode == tf_estimator.ModeKeys.EVAL:
    train_op = None
    predictions = None
    eval_params = {
        "num_label_classes": params["num_label_classes"],
        "log_class_level_summaries": False
    }
    eval_metrics = class_level_metrics.create_eval_metrics(
        labels, logits, human_labels, eval_params)

  return tf_estimator.EstimatorSpec(
      predictions=predictions,
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=eval_metrics)
def model_function(features, labels, mode, params, embeddings):
  """A model function satisfying the tf.estimator API.

  Builds a span-prediction reader: question and context are embedded
  (word + character CNN embeddings, with the "embed" variable scope shared
  between them). `document_reader.score_endpoints` then produces start/end
  logits over context positions, and the best span is chosen with
  `span_utils.max_scoring_span`.

  Args:
    features: Dictionary of feature tensors with keys:
        - question: <string> [batch_size, max_question_len]
        - question_len: <int32> [batch_size]
        - question_cid: <int32> [batch_size, max_question_len, max_chars]
        - question_wid: <int32> [batch_size, max_question_len]
        - context: <string> [batch_size, max_context_len]
        - context_len: <int32> [batch_size]
        - context_cid: <int32> [batch_size, max_context_len, max_chars]
        - context_wid: <int32> [batch_size, max_context_len]
        - answer_start: <int32> [batch_size]
        - answer_end: <int32> [batch_size]
      In PREDICT mode the features are unbatched; a dummy batch dimension is
      added here and removed again from the predictions.
    labels: Pair of tensors containing the answer start and answer end.
      Ignored in PREDICT mode.
    mode: One of the keys from tf.estimator.ModeKeys.
    params: Unused parameter dictionary.
    embeddings: An embedding_utils.PretrainedWordEmbeddings object.

  Returns:
    estimator_spec: A tf.estimator.EstimatorSpec object.
  """
  del params

  if mode == tf_estimator.ModeKeys.PREDICT:
    # Add a dummy batch dimension if we are exporting the predictor.
    features = {k: tf.expand_dims(v, 0) for k, v in features.items()}

  # The pretrained word embeddings are frozen (trainable=False); the returned
  # scaffold is attached to the EstimatorSpec below.
  embedding_weights, embedding_scaffold = embeddings.get_params(
      trainable=False)

  def _embed(prefix):
    """Embed the input text based on word and character IDs."""
    word_emb = tf.nn.embedding_lookup(embedding_weights,
                                      features[prefix + "_wid"])
    char_emb = common_layers.character_cnn(
        char_ids=features[prefix + "_cid"],
        emb_size=FLAGS.char_emb_size,
        kernel_width=FLAGS.char_kernel_width,
        num_filters=FLAGS.num_char_filters)
    concat_emb = tf.concat([word_emb, char_emb], -1)

    # Dropout only while training; the second argument is the keep
    # probability, hence 1 - dropout_ratio.
    if mode == tf_estimator.ModeKeys.TRAIN:
      concat_emb = tf.nn.dropout(concat_emb, 1.0 - FLAGS.dropout_ratio)

    return concat_emb

  with tf.variable_scope("embed"):
    # [batch_size, max_question_len, hidden_size]
    question_emb = _embed("question")

  # reuse=True: the context shares the character-CNN variables created for
  # the question above.
  with tf.variable_scope("embed", reuse=True):
    # [batch_size, max_context_len, hidden_size]
    context_emb = _embed("context")

  # [batch_size, max_context_len]
  start_logits, end_logits = document_reader.score_endpoints(
      question_emb=question_emb,
      question_len=features["question_len"],
      context_emb=context_emb,
      context_len=features["context_len"],
      hidden_size=FLAGS.hidden_size,
      num_layers=FLAGS.num_layers,
      dropout_ratio=FLAGS.dropout_ratio,
      mode=mode,
      use_cudnn=False if mode == tf_estimator.ModeKeys.PREDICT else None)

  if mode != tf_estimator.ModeKeys.PREDICT:
    # [batch_size]
    start_labels, end_labels = labels

    # Since we truncate long contexts, some of the labels will not be
    # recoverable. In that case, we mask these invalid labels.
    valid_start_labels = tf.less(start_labels, features["context_len"])
    valid_end_labels = tf.less(end_labels, features["context_len"])
    tf.summary.histogram("valid_start_labels", tf.to_float(valid_start_labels))
    tf.summary.histogram("valid_end_labels", tf.to_float(valid_end_labels))

    # Invalid labels are replaced by index 0 and then given weight 0, so they
    # contribute nothing to the loss.
    dummy_labels = tf.zeros_like(start_labels)

    # []
    start_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=tf.where(valid_start_labels, start_labels, dummy_labels),
        logits=start_logits,
        weights=tf.to_float(valid_start_labels),
        reduction=tf.losses.Reduction.MEAN)
    end_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=tf.where(valid_end_labels, end_labels, dummy_labels),
        logits=end_logits,
        weights=tf.to_float(valid_end_labels),
        reduction=tf.losses.Reduction.MEAN)
    loss = start_loss + end_loss
  else:
    loss = None

  if mode == tf_estimator.ModeKeys.TRAIN:
    # Adam with gradients clipped to a global norm of 5.0.
    optimizer = tf.train.AdamOptimizer()
    gradients, variables = list(zip(*optimizer.compute_gradients(loss)))
    gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
    train_op = optimizer.apply_gradients(
        grads_and_vars=list(zip(gradients, variables)),
        global_step=tf.train.get_global_step())
  else:
    # Don't build the train_op unnecessarily, since the ADAM variables can cause
    # problems with loading checkpoints on CPUs.
    train_op = None

  batch_size, max_context_len = tensor_utils.shape(features["context_wid"])
  tf.summary.histogram("batch_size", batch_size)
  # Fraction of each context row that is real (non-padding) tokens.
  tf.summary.histogram("non_padding", features["context_len"] / max_context_len)

  # [batch_size], [batch_size]
  start_predictions, end_predictions, predicted_score = (
      span_utils.max_scoring_span(start_logits, end_logits))

  # Dict of per-example tensors. end_idx is end_predictions + 1, presumably
  # turning an inclusive end index into an exclusive one -- TODO confirm.
  predictions = dict(
      start_idx=start_predictions,
      end_idx=(end_predictions + 1),
      score=predicted_score)

  if mode == tf_estimator.ModeKeys.PREDICT:
    # Remove the dummy batch dimension if we are exporting the predictor.
    predictions = {k: tf.squeeze(v, 0) for k, v in predictions.items()}

  if mode == tf_estimator.ModeKeys.EVAL:
    text_summary = get_text_summary(
        question=features["question"],
        context=features["context"],
        start_predictions=start_predictions,
        end_predictions=end_predictions)

    # TODO(kentonl): Replace this with @mingweichang's official eval script.
    # Compares the raw (un-shifted) end_predictions with end_labels; both are
    # indices into end_logits.
    exact_match = tf.logical_and(
        tf.equal(start_predictions, start_labels),
        tf.equal(end_predictions, end_labels))

    # The text summary is exposed through eval_metric_ops with a no-op
    # update so it gets written during evaluation.
    eval_metric_ops = dict(
        exact_match=tf.metrics.mean(exact_match),
        text_summary=(text_summary, tf.no_op()))
  else:
    eval_metric_ops = None

  estimator_spec = tf_estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      predictions=predictions,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops,
      scaffold=embedding_scaffold)

  return estimator_spec
def resnet_model_fn_w_pruning(features, labels, mode, params):
  """The model_fn for ResNet-50 with pruning.

  Args:
    features: A float32 batch of images, or a dict holding it under the
      'feature' key.
    labels: A int32 batch of labels.
    mode: Specifies whether training, evaluation or prediction.
    params: Dictionary of parameters passed to the model. Reads
      'pruning_method', 'use_tpu', 'log_alpha_threshold' and (outside PREDICT
      mode) 'output_dir'.

  Returns:
    A TPUEstimatorSpec for the model in TRAIN/EVAL mode, or a plain
    EstimatorSpec in PREDICT mode.

  Raises:
    ValueError: If FLAGS.precision is neither 'bfloat16' nor 'float32'; or,
      outside PREDICT mode with FLAGS.pruning_method == 'scratch', if
      FLAGS.load_mask_dir is empty or its path component 15 is not a float
      label-smoothing value.
  """
  width = 1. if FLAGS.width <= 0 else FLAGS.width
  if isinstance(features, dict):
    features = features['feature']

  if FLAGS.data_format == 'channels_first':
    assert not FLAGS.transpose_input  # channels_first only for GPU
    features = tf.transpose(features, [0, 3, 1, 2])

  if FLAGS.transpose_input and mode != tf_estimator.ModeKeys.PREDICT:
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

  # Normalize the image to zero mean and unit variance. The per-channel
  # constants must broadcast along the channel axis: axis 1 after the
  # channels_first transpose above ([N, C, H, W]), the last axis otherwise.
  if FLAGS.data_format == 'channels_first':
    channel_shape = [3, 1, 1]
  else:
    channel_shape = [1, 1, 3]
  features -= tf.constant(MEAN_RGB, shape=channel_shape, dtype=features.dtype)
  features /= tf.constant(
      STDDEV_RGB, shape=channel_shape, dtype=features.dtype)

  pruning_method = params['pruning_method']
  use_tpu = params['use_tpu']
  log_alpha_threshold = params['log_alpha_threshold']

  def build_network():
    """Construct the network in the graph."""
    model_pruning_method = pruning_method
    if pruning_method == 'scratch':
      model_pruning_method = 'threshold'

    network = resnet_model.resnet_v1_(
        resnet_depth=FLAGS.resnet_depth,
        num_classes=FLAGS.num_label_classes,
        # we need to construct the model with the pruning masks, but they won't
        # be updated if we're doing scratch training
        pruning_method=model_pruning_method,
        init_method=FLAGS.init_method,
        width=width,
        prune_first_layer=FLAGS.prune_first_layer,
        prune_last_layer=FLAGS.prune_last_layer,
        data_format=FLAGS.data_format,
        end_sparsity=FLAGS.end_sparsity,
        clip_log_alpha=FLAGS.clip_log_alpha,
        log_alpha_threshold=log_alpha_threshold,
        weight_decay=FLAGS.weight_decay)
    return network(
        inputs=features, is_training=(mode == tf_estimator.ModeKeys.TRAIN))

  if FLAGS.precision == 'bfloat16':
    with contrib_tpu.bfloat16_scope():
      logits = build_network()
    logits = tf.cast(logits, tf.float32)
  elif FLAGS.precision == 'float32':
    logits = build_network()
  else:
    # Previously `logits` was left unbound here, failing with a NameError.
    raise ValueError('Unsupported precision %r; expected "bfloat16" or '
                     '"float32".' % (FLAGS.precision,))

  if mode == tf_estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf_estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf_estimator.export.PredictOutput(predictions)
        })

  output_dir = params['output_dir']

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)

  # make sure we reuse the same label smoothing parameter if we're doing
  # scratch / lottery ticket experiments.
  label_smoothing = FLAGS.label_smoothing
  if FLAGS.pruning_method == 'scratch':
    # Validate before parsing: an empty load_mask_dir would otherwise fail
    # with an IndexError instead of this clear message.
    if not FLAGS.load_mask_dir:
      raise ValueError('Must supply a mask directory to use scratch method')
    # NOTE(review): assumes the label smoothing value is path component 15 of
    # load_mask_dir -- confirm against the directory layout in use.
    mask_dir_parts = FLAGS.load_mask_dir.split('/')
    try:
      label_smoothing = float(mask_dir_parts[15])
    except (IndexError, ValueError):
      raise ValueError(
          'Cannot read label smoothing from path component 15 of '
          'load_mask_dir: %r' % (FLAGS.load_mask_dir,))

  loss = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=one_hot_labels,
      label_smoothing=label_smoothing)
  # Add regularization loss term
  loss += tf.losses.get_regularization_loss()

  if pruning_method == 'variational_dropout':
    reg_loss = utils.variational_dropout_dkl_loss(
        reg_scalar=FLAGS.reg_scalar,
        start_reg_ramp_up=FLAGS.sparsity_begin_step,
        end_reg_ramp_up=FLAGS.sparsity_end_step,
        warm_up=FLAGS.is_warm_up,
        use_tpu=use_tpu)
    loss += reg_loss
    tf.losses.add_loss(reg_loss, loss_collection=tf.GraphKeys.LOSSES)
  elif pruning_method == 'l0_regularization':
    reg_loss = utils.l0_regularization_loss(
        reg_scalar=FLAGS.reg_scalar,
        start_reg_ramp_up=FLAGS.sparsity_begin_step,
        end_reg_ramp_up=FLAGS.sparsity_end_step,
        warm_up=FLAGS.is_warm_up,
        use_tpu=use_tpu)
    loss += reg_loss
    tf.losses.add_loss(reg_loss, loss_collection=tf.GraphKeys.LOSSES)

  host_call = None
  if mode == tf_estimator.ModeKeys.TRAIN:
    host_call, train_op = train_function(pruning_method, loss, output_dir,
                                         use_tpu)
  else:
    train_op = None

  eval_metrics = None
  if mode == tf_estimator.ModeKeys.EVAL:

    def metric_fn(labels, logits):
      """Calculate eval metrics."""
      logging.info('In metric function')
      eval_metrics = {}
      predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      eval_metrics['top_5_eval_accuracy'] = tf.metrics.mean(in_top_5)
      eval_metrics['eval_accuracy'] = tf.metrics.accuracy(
          labels=labels, predictions=predictions)

      return eval_metrics

    def vd_metric_fn(labels, logits, global_sparsity):
      """metric_fn plus the mean global sparsity of the VD masks."""
      eval_metrics = metric_fn(labels, logits)
      eval_metrics['global_sparsity'] = tf.metrics.mean(global_sparsity)
      return eval_metrics

    tensors = [labels, logits]
    metric_function = metric_fn

    if FLAGS.pruning_method == 'variational_dropout':
      # Broadcast the scalar sparsity to a per-example [batch_size, 1] tensor
      # so it can be passed alongside the batched labels/logits.
      batch_size = labels.shape[0]
      ones = tf.ones([batch_size, 1])
      mask_metrics = utils.add_vd_pruning_summaries(
          threshold=FLAGS.log_alpha_threshold)
      tensors.append(mask_metrics['global_sparsity'] * ones)
      metric_function = vd_metric_fn

    eval_metrics = (metric_function, tensors)

  def training_already_started():
    """Returns whether FLAGS.output_dir already contains a checkpoint."""
    model_dir = FLAGS.output_dir
    return bool(model_dir) and tf.train.latest_checkpoint(
        model_dir) is not None

  # define a custom scaffold function to enable initializing the mask from an
  # already trained checkpoint.
  def initialize_mask_from_ckpt(ckpt_path):
    """Load mask from an existing checkpoint."""
    if training_already_started():
      tf.logging.info(
          'Training already started on this model, not loading masks from '
          'previously trained model')
      return

    reader = tf.train.NewCheckpointReader(ckpt_path)
    mask_names = {
        name for name in reader.get_variable_to_shape_map()
        if name.endswith('mask')
    }

    variable_map = {}
    for var in tf.global_variables():
      var_name = var.name.split(':')[0]
      if var_name in mask_names:
        tf.logging.info('Loading mask variable from checkpoint: %s', var_name)
        variable_map[var_name] = var
      elif 'mask' in var_name:
        tf.logging.info(
            'Cannot find mask variable in checkpoint, skipping: %s', var_name)
    tf.train.init_from_checkpoint(ckpt_path, variable_map)

  def initialize_parameters_from_ckpt(ckpt_path):
    """Load parameters from an existing checkpoint."""
    if training_already_started():
      tf.logging.info(
          'Training already started on this model, not loading parameters '
          'from previously trained model')
      return

    reader = tf.train.NewCheckpointReader(ckpt_path)
    param_names = {
        name for name in reader.get_variable_to_shape_map()
        if not name.endswith('mask')
    }

    variable_map = {}
    for var in tf.global_variables():
      var_name = var.name.split(':')[0]
      if var_name in param_names:
        tf.logging.info(
            'Loading parameter variable from checkpoint: %s', var_name)
        variable_map[var_name] = var
      elif 'mask' not in var_name:
        tf.logging.info(
            'Cannot find parameter variable in checkpoint, skipping: %s',
            var_name)
    tf.train.init_from_checkpoint(ckpt_path, variable_map)

  if FLAGS.pruning_method == 'scratch':
    # FLAGS.load_mask_dir was validated while reading the label smoothing
    # value (PREDICT mode returns before reaching this point).
    def scaffold_fn():
      """Initializes masks, and optionally parameters, from checkpoints."""
      initialize_mask_from_ckpt(FLAGS.load_mask_dir)
      if FLAGS.initial_value_checkpoint:
        initialize_parameters_from_ckpt(FLAGS.initial_value_checkpoint)
      return tf.train.Scaffold()
  else:
    scaffold_fn = None

  return contrib_tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics,
      scaffold_fn=scaffold_fn)
def model_fn(features, labels, mode):
  """AdversarialReweightingModel model_fn.

  Args:
    features: `Tensor` or `dict` of `Tensor`.
    labels: A `dict` of `Tensor` Objects. Expects to have a key/value pair
      for the key self.label_column_name.
    mode: Defines whether this is training, evaluation or prediction.
      See `ModeKeys`. Only TRAIN and EVAL are supported; PREDICT is not
      implemented.

  Returns:
    An instance of `tf.estimator.EstimatorSpec`, which encapsulates the
    `mode`, `predictions`, `loss` and (in TRAIN mode) the `train_op`.
    `predictions` is a `dict` of `Tensor` objects representing the prediction
    of the binary classification model plus the per-example weights. In
    EVAL mode `loss` is the primary loss; in TRAIN mode it is the sum of the
    primary and adversary losses. `train_op` is the op for training.

  Raises:
    NotImplementedError: If `mode` is neither TRAIN nor EVAL. Without this
      check such a mode fails later with an opaque error, because no
      EstimatorSpec is ever built for it.
  """
  if mode not in (tf_estimator.ModeKeys.TRAIN, tf_estimator.ModeKeys.EVAL):
    raise NotImplementedError(
        'AdversarialReweightingModel model_fn does not support mode: '
        '{}'.format(mode))

  # Instantiates a tensor with weight for positive class examples only
  pos_weights = tf.cast(
      tf.equal(labels[self._label_column_name], 1), dtype=tf.float32)
  # Instantiates a tensor with true class labels
  class_labels = labels[self._label_column_name]
  # Initialize a global step variable used for alternate training
  current_step = self._get_or_create_global_step_var()

  if mode == tf_estimator.ModeKeys.EVAL:
    tf.logging.info('model_fn: EVAL, {}'.format(mode))
  elif mode == tf_estimator.ModeKeys.TRAIN:
    tf.logging.info('model_fn: TRAIN, {}'.format(mode))

  # Creates a DNN architecture for primary binary classification task
  with tf.name_scope('primary_NN'):
    with tf.variable_scope('primary'):
      input_layer = tf.feature_column.input_layer(features,
                                                  self._feature_columns)
      h1 = tf.layers.Dense(
          self._primary_hidden_units[0],
          activation=self._activation)(input_layer)
      h2 = tf.layers.Dense(
          self._primary_hidden_units[1], activation=self._activation)(h1)
      logits = tf.layers.Dense(1)(h2)
      sigmoid_output = tf.nn.sigmoid(logits, name='sigmoid')
      class_predictions = tf.cast(
          tf.greater(sigmoid_output, 0.5), tf.float32)
      tf.summary.histogram('class_predictions', class_predictions)

  # Creates a network architecture for the adversarial regression task
  with tf.name_scope('adversary_NN'):
    with tf.variable_scope('adversary'):
      # Gets adversary features and features columns
      adversarial_features, adversary_feature_columns = (
          self._get_adversary_features_and_feature_columns(features, labels))
      adv_input_layer = tf.feature_column.input_layer(
          adversarial_features, adversary_feature_columns)
      adv_h1 = tf.layers.Dense(
          self._adversary_hidden_units[0])(adv_input_layer)
      adv_output_layer = tf.layers.Dense(1, use_bias=True)(adv_h1)
      # Until `pretrain_steps` every example gets weight 1; afterwards the
      # weights come from the adversary's output.
      example_weights = tf.cond(
          tf.greater(current_step, self._pretrain_steps),
          true_fn=lambda: self._compute_example_weights(adv_output_layer),
          false_fn=lambda: tf.ones_like(class_labels))

  # Adds summary variables to tensorboard
  with tf.name_scope('example_weights'):
    tf.summary.histogram('example_weights', example_weights)
    tf.summary.histogram('label', class_labels)

  # Initializes Loss Functions
  primary_loss = self._primary_loss(class_labels, logits, example_weights)
  adversary_loss = self._adversary_loss(class_labels, logits, pos_weights,
                                        example_weights,
                                        self._adversary_loss_type)

  # Sets up dictionaries used for computing performance metrics
  predictions = {
      (self._label_column_name, 'class_ids'):
          tf.reshape(class_predictions, [-1]),
      (self._label_column_name, 'logistic'):
          tf.reshape(sigmoid_output, [-1]),
      # A plain string key (the original `('example_weights')` is not a
      # tuple, just a parenthesized string).
      'example_weights': tf.reshape(example_weights, [-1])
  }

  class_id_kwargs = {
      'labels': class_labels,
      'predictions': class_predictions
  }
  logistics_kwargs = {
      'labels': class_labels,
      'predictions': sigmoid_output
  }

  # EVAL Mode
  if mode == tf_estimator.ModeKeys.EVAL:
    with tf.name_scope('eval_metrics'):
      eval_metric_ops = {
          'accuracy': tf.metrics.accuracy(**class_id_kwargs),
          'precision': tf.metrics.precision(**class_id_kwargs),
          'recall': tf.metrics.recall(**class_id_kwargs),
          'fp': tf.metrics.false_positives(**class_id_kwargs),
          'fn': tf.metrics.false_negatives(**class_id_kwargs),
          'tp': tf.metrics.true_positives(**class_id_kwargs),
          'tn': tf.metrics.true_negatives(**class_id_kwargs),
          'fpr': contrib_metrics.streaming_false_positive_rate(
              **class_id_kwargs),
          'fnr': contrib_metrics.streaming_false_negative_rate(
              **class_id_kwargs),
          'auc': tf.metrics.auc(curve='ROC', **logistics_kwargs),
          'aucpr': tf.metrics.auc(curve='PR', **logistics_kwargs)
      }

    # EstimatorSpec object for evaluation
    return tf_estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=primary_loss,
        eval_metric_ops=eval_metric_ops)

  # TRAIN Mode (the only mode left after the check at the top).
  # Filters trainable variables for each task by variable-scope name.
  all_trainable_vars = tf.trainable_variables()
  primary_trainable_vars = [
      v for v in all_trainable_vars if 'primary' in v.op.name
  ]
  adversary_trainable_vars = [
      v for v in all_trainable_vars if 'adversary' in v.op.name
  ]

  # TRAIN_OP for adversary DNN
  train_op_adversary = contrib_layers.optimize_loss(
      loss=adversary_loss,
      variables=adversary_trainable_vars,
      global_step=contrib_framework.get_global_step(),
      learning_rate=self._adversary_learning_rate,
      optimizer=self._optimizer)

  # TRAIN_OP for primary DNN
  train_op_primary = contrib_layers.optimize_loss(
      loss=primary_loss,
      variables=primary_trainable_vars,
      global_step=contrib_framework.get_global_step(),
      learning_rate=self._primary_learning_rate,
      optimizer=self._optimizer)

  # Upto ``pretrain_steps'' trains primary only.
  # Beyond ``pretrain_steps'' alternates between primary and adversary.
  # NOTE(review): per the tf.cond documentation, ops created outside
  # true_fn/false_fn run regardless of which branch is selected.
  # train_op_adversary is created above, outside the cond, so it is likely
  # also executed during the pretraining steps. Both optimize_loss calls are
  # also given the global step. TODO confirm this is the intended schedule.
  return tf_estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=primary_loss + adversary_loss,
      train_op=tf.cond(
          tf.greater(current_step, self._pretrain_steps),
          true_fn=lambda: tf.group([train_op_primary, train_op_adversary]),
          false_fn=lambda: tf.group([train_op_primary])))