示例#1
0
    def model_fn(features, mode, params):
        """Build the estimator spec for a `TPUEstimator`.

        Args:
          features: Mapping of input tensors. `features["label"]` is read in
            every mode except PREDICT.
          mode: One of the `tf_estimator.ModeKeys` values.
          params: Passed through to `create_optimizer` in TRAIN mode.

        Returns:
          A `TPUEstimatorSpec` for TRAIN and EVAL, or a plain `EstimatorSpec`
          carrying the logits and their activations for PREDICT.
        """
        mode_keys = tf_estimator.ModeKeys
        label_ids = features["label"] if mode != mode_keys.PREDICT else None

        model_config = runner_config["model_config"]
        loss, logits = create_model(model, model_config, features, mode,
                                    runner_config["name"])

        if mode == mode_keys.TRAIN:
            train_op = create_optimizer(loss, runner_config, params)
            return tf_estimator.tpu.TPUEstimatorSpec(
                mode=mode, loss=loss, train_op=train_op)

        if mode == mode_keys.EVAL:
            # Multilabel models use the labeling metric; everything else uses
            # the classification metric.
            metric_fn = (metric_functions.labeling_metric
                         if runner_config["model_config"]["multilabel"] else
                         metric_functions.classification_metric)
            return tf_estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=(metric_fn, [loss, label_ids, logits]))

        if mode == mode_keys.PREDICT:
            # Sigmoid for multilabel outputs, softmax otherwise.
            activation = (tf.math.sigmoid
                          if runner_config["model_config"]["multilabel"] else
                          tf.nn.softmax)
            return tf_estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    "logits": logits,
                    "predictions": activation(logits),
                })

        assert False, "Expected to be called in TRAIN, EVAL, or PREDICT mode."
示例#2
0
    def estimator_spec_predict(self, features, mesh, mesh_impl, use_tpu):
        """Construct the PREDICT-mode spec from samples lowered out of `mesh`.

        Args:
          features: Mapping of input tensors; `"inputs"` and `"infer_targets"`
            are echoed into the predictions when present and not None.
          mesh: The mesh passed to `self.sample` and lowered here.
          mesh_impl: The mesh implementation used for lowering.
          use_tpu: If true, return a `TPUEstimatorSpec` (after removing
            summaries); otherwise return an `EstimatorSpec`.

        Returns:
          The estimator spec, with a `MtfRestoreHook` as its prediction hook.
        """
        anonymous_samples = mtf.anonymize(self.sample(features, mesh))
        lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
        outputs = lowering.export_to_tf_tensor(anonymous_samples)

        if self.has_input:
            # Slice the leading dimension down to the leading dimension of
            # features["inputs"]; all other dimensions are kept whole.
            rank = len(outputs.shape.as_list())
            actual_batch_size = tf.shape(features["inputs"])[0]
            begin = [0] * rank
            size = [actual_batch_size] + [-1] * (rank - 1)
            outputs = tf.slice(outputs, begin, size)

        predictions = {"outputs": outputs}
        for passthrough_key in ("infer_targets", "inputs"):
            if features.get(passthrough_key) is not None:
                predictions[passthrough_key] = features[passthrough_key]

        if not use_tpu:
            return tf_estimator.EstimatorSpec(
                tf_estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        t2t_model.remove_summaries()
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf_estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])
示例#3
0
    def model_function(features, labels, mode, params):
        """Build the `EstimatorSpec` used for training and evaluation.

        Args:
          features: Mapping holding the `"premise"` and `"hypothesis"` inputs.
          labels: Class labels; cast to int32 for the loss and the metric.
          mode: A `tf_estimator.ModeKeys` value.
          params: Unused here.

        Returns:
          An `EstimatorSpec` with the mean cross-entropy loss, an accuracy
          metric, and (when `FLAGS.checkpoint_file` is set) a scaffold that
          restores the trainable variables from that checkpoint.
        """
        training = mode == tf_estimator.ModeKeys.TRAIN
        logits = predict(training, embeddings, features["premise"],
                         features["hypothesis"])

        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.to_int32(labels), logits=logits)
        loss = tf.reduce_mean(per_example_loss)

        # Don't build the train_op unnecessarily, since the ADAM variables can
        # cause problems with loading checkpoints on CPUs.
        train_op = get_train_op(loss) if training else None

        predicted_classes = tf.argmax(logits, 1, output_type=tf.int32)
        metrics = {
            "accuracy": tf.metrics.accuracy(predicted_classes,
                                            tf.to_int32(labels)),
        }

        scaffold = None
        checkpoint_file = FLAGS.checkpoint_file
        if checkpoint_file is not None:
            saver = tf.train.Saver(tf.trainable_variables())

            def _init_fn(_, sess):
                saver.restore(sess, checkpoint_file)

            scaffold = tf.train.Scaffold(init_fn=_init_fn)

        return tf_estimator.EstimatorSpec(mode=mode,
                                          scaffold=scaffold,
                                          loss=loss,
                                          predictions=None,
                                          train_op=train_op,
                                          eval_metric_ops=metrics)
示例#4
0
def model_fn(features, labels, mode, params):
  """Retrieve blocks, read them, and assemble the `EstimatorSpec`.

  Args:
    features: Input features forwarded to the retriever and the reader.
    labels: Labels; cast to int32 and used outside PREDICT mode.
    mode: A `tf_estimator.ModeKeys` value.
    params: Mapping providing `reader_beam_size`, `num_classes`,
      `retriever_beam_size` (outside PREDICT), `learning_rate` and
      `num_train_steps`.

  Returns:
    An `EstimatorSpec` whose predictions are `{"answer": <class index>}`.
    Loss, train op and eval metrics are None in PREDICT mode.
  """
  predicting = mode == tf_estimator.ModeKeys.PREDICT
  reader_beam_size = params["reader_beam_size"]
  num_classes = params["num_classes"]
  # In PREDICT mode the retriever only needs to return what the reader uses.
  retriever_beam_size = (
      reader_beam_size if predicting else params["retriever_beam_size"])
  assert reader_beam_size <= retriever_beam_size

  with tf.device("/cpu:0"):
    retrieved = orqa_model.retrieve(
        features=features,
        retriever_beam_size=retriever_beam_size,
        mode=mode,
        params=params)

  with tf.variable_scope("reader"):
    # [reader_beam_size, num_classes]
    final_logits = read(
        features=features,
        retriever_logits=retrieved.logits[:reader_beam_size],
        blocks=retrieved.blocks[:reader_beam_size],
        mode=mode,
        params=params)

  # Pick the most confident entry across all retrievals: argmax over the
  # flattened logits, then floormod to recover the class index.
  flat_logits = tf.reshape(final_logits, [reader_beam_size * num_classes])
  predictions = tf.math.floormod(tf.argmax(flat_logits), num_classes)

  loss = None
  train_op = None
  eval_metric_ops = None
  if not predicting:
    labels = tf.cast(labels, tf.int32)

    eval_metric_ops = compute_eval_metrics(
        labels=labels, predictions=predictions)

    loss = compute_loss(labels, final_logits)

    # Warmup is a tenth of the training steps, clamped to [100, 10000].
    train_op = optimization.create_optimizer(
        loss=loss,
        init_lr=params["learning_rate"],
        num_train_steps=params["num_train_steps"],
        num_warmup_steps=min(
            10000, max(100, int(params["num_train_steps"] / 10))),
        use_tpu=False)

  return tf_estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      predictions={"answer": predictions},
      eval_metric_ops=eval_metric_ops)
示例#5
0
 def model_fn(features, labels, mode):
   """Build the `EstimatorSpec` for PREDICT, EVAL, or TRAIN mode.

   Args:
     features: Mapping of input tensors; `features["input_refs"]` is used for
       the `input_count` summary.
     labels: Ignored (deleted immediately).
     mode: A `tf_estimator.ModeKeys` value.

   Returns:
     An `EstimatorSpec` with predictions (PREDICT), loss and metrics (EVAL),
     or loss and train op (TRAIN).
   """
   del labels
   refs = features["input_refs"]
   # Count the references whose index-1 entry exceeds their index-0 entry.
   input_count = tf.reduce_sum(
       tf.to_int32(tf.greater(refs[:, :, 1], refs[:, :, 0])))
   tf.summary.scalar("input_count", input_count)
   loss_dict, pred_dict, areas = seq2act_model.core_graph(
       features, hparams, mode, compute_additional_loss_fn)

   if mode == tf_estimator.ModeKeys.PREDICT:
     pred_dict["sequences"] = decode_sequence(
         features, areas, hparams, decode_length,
         post_processing=FLAGS.post_processing)
     return tf_estimator.EstimatorSpec(mode, predictions=pred_dict)

   if mode == tf_estimator.ModeKeys.EVAL:
     metrics = {}
     _eval(metrics, pred_dict, loss_dict, features,
           areas, compute_seq_accuracy,
           hparams,
           metric_types=FLAGS.metric_types.split(","),
           decode_length=decode_length)
     if compute_additional_metric_fn:
       compute_additional_metric_fn(metrics, pred_dict, features)
     return tf_estimator.EstimatorSpec(
         mode, loss=loss_dict["total_loss"], eval_metric_ops=metrics)

   assert mode == tf_estimator.ModeKeys.TRAIN
   loss = loss_dict["total_loss"]
   # Summarize every loss except the total and entries ending in "losses".
   for loss_name in loss_dict:
     if loss_name == "total_loss" or loss_name.endswith("losses"):
       continue
     tf.summary.scalar(loss_name, loss_dict[loss_name])

   step_num = tf.to_float(tf.train.get_global_step())
   # The schedule string is a "*"-separated list of factor names; the
   # learning rate is the product of the corresponding factors.
   factor_names = [
       piece.strip()
       for piece in hparams.learning_rate_schedule.split("*")
       if piece.strip()
   ]
   rate = tf.constant(1.0)
   for factor_name in factor_names:
     rate *= learning_rate.learning_rate_factor(factor_name, step_num, hparams)
   train_op = optimize.optimize(loss, rate, hparams)
   return tf_estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
def _build_estimator_spec(losses, trainer_params, mode, use_tpu=False):
    """Build an EstimatorSpec/TPUEstimatorSpec based on trainer_params.

    Args:
      losses: A dictionary of {string: tf.Tensor} containing the various
        losses. The keys are used as display names for the summaries; the
        values are summed to obtain the total loss, which is minimized.
      trainer_params: A ParameterContainer object with parameters relevant to
        the training (`learning_rate` and `clip_gradients` are read here).
      mode: One of tf.estimator.ModeKeys: TRAIN, PREDICT or EVAL.
      use_tpu: A boolean, if True, a TPU-compatible version of EstimatorSpec
        is built.

    Returns:
      An EstimatorSpec or a TPUEstimatorSpec object.
    """
    total_loss = None
    train_op = None

    if mode == tf_estimator.ModeKeys.TRAIN:
        total_loss = 0.0
        for name, value in six.iteritems(losses):
            # Per-loss summaries are only written when not running on TPU.
            if not use_tpu:
                tf.summary.scalar('Loss/%s' % name, value)
            total_loss += value

        rate = trainer_params.learning_rate
        maybe_summary.scalar('Learning Rate', rate)
        optimizer = contrib_estimator.clip_gradients_by_norm(
            tf.train.AdamOptimizer(learning_rate=rate, beta1=0.9),
            trainer_params.clip_gradients)
        if use_tpu:
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)

        train_op = optimizer.minimize(
            total_loss, global_step=tf.train.get_global_step())

    # NOTE(review): the spec is always built with ModeKeys.TRAIN, whatever
    # `mode` was passed in; outside TRAIN both loss and train_op are None.
    # Confirm whether non-TRAIN modes are meant to be supported here.
    spec_cls = (tf_estimator.tpu.TPUEstimatorSpec
                if use_tpu else tf_estimator.EstimatorSpec)
    return spec_cls(mode=tf_estimator.ModeKeys.TRAIN,
                    loss=total_loss,
                    train_op=train_op)
示例#7
0
    def estimator_spec_eval(self, features, logits, labels, loss, restore_hook,
                            use_tpu):
        """Construct the estimator spec for EVAL mode.

        Args:
          features: Mapping of input tensors; `features["targets"]` is the
            label tensor handed to each metric in the non-TPU path.
          logits: Model logits. Rank-3 logits get two extra axes inserted
            (at positions 2 and 3) before use.
          labels: Labels passed to the metric function in the TPU path.
          loss: Loss tensor stored on the spec.
          restore_hook: Hook added as the only evaluation hook.
          use_tpu: Selects between `TPUEstimatorSpec` and `EstimatorSpec`.

        Returns:
          The EVAL-mode estimator spec.
        """
        hparams = self.hparams
        problem = hparams.problem
        if logits.get_shape().ndims == 3:
            logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)

        # Support for multiproblem: use the problem's own task list if it
        # has one, otherwise treat the problem as the single task.
        task_list = (problem.task_list
                     if hasattr(problem, "task_list") else [problem])

        eval_metrics_fns = metrics.create_evaluation_metrics(
            task_list, hparams)

        if not use_tpu:
            eval_metric_ops = {
                name: fn(logits, features, features["targets"])
                for name, fn in six.iteritems(eval_metrics_fns)
            }
            return tf_estimator.EstimatorSpec(
                tf_estimator.ModeKeys.EVAL,
                predictions={"predictions": logits},
                eval_metric_ops=eval_metric_ops,
                evaluation_hooks=[restore_hook],
                loss=loss)

        def metric_fn(tf_logits, labels):
            """Evaluate every metric whose short name is not blacklisted."""
            with tf.device("cpu:0"), mtf.utils.outside_all_rewrites():
                blacklist = t2t_model.TPU_METRIC_BLACKLIST
                return {
                    name: fn(tf_logits, None, tf.identity(labels))
                    for name, fn in six.iteritems(eval_metrics_fns)
                    if name.split("/")[-1] not in blacklist
                }

        return tpu_estimator.TPUEstimatorSpec(
            tf_estimator.ModeKeys.EVAL,
            evaluation_hooks=[restore_hook],
            loss=loss,
            eval_metrics=(metric_fn, [logits, labels]))
示例#8
0
    def _model_fn(features, labels, mode):
        """Model function built around a mock TF-Hub text module.

        The module's output on `features[_TEXT_FEATURE_NAME]` becomes the
        predictions; the loss is a constant zero and the train op only
        increments the global step.
        """
        del labels

        text_module = hub.Module(hub.create_module_spec(text_module_fn))
        if register_module:
            hub.register_module_for_export(text_module, _EXPORT_MODULE_NAME)

        module_output = text_module(features[_TEXT_FEATURE_NAME])
        zero_loss = tf.constant(0.0)

        step = tf.compat.v1.train.get_global_step()
        increment_step = tf.compat.v1.assign_add(step, 1)

        return tf_estimator.EstimatorSpec(mode=mode,
                                          predictions=module_output,
                                          loss=zero_loss,
                                          train_op=increment_step)
示例#9
0
 def _simple_model_fn(self, features, labels, mode, params):
     """Single dense-layer model with a sigmoid cross-entropy loss.

     Returns a `TPUEstimatorSpec` when `params["use_tpu"]` is truthy (with the
     optimizer wrapped in `CrossShardOptimizer`), otherwise an `EstimatorSpec`.
     """
     logits = tf.squeeze(tf.layers.dense(features, 1))
     per_example_loss = tf.nn.sigmoid_cross_entropy_with_logits(
         labels=tf.to_float(labels), logits=logits)
     loss = tf.reduce_mean(per_example_loss)

     optimizer = tf.train.GradientDescentOptimizer(0.1)
     use_tpu = params["use_tpu"]
     if use_tpu:
         optimizer = tf.tpu.CrossShardOptimizer(optimizer)
     train_op = optimizer.minimize(
         loss, global_step=tf.train.get_or_create_global_step())

     spec_cls = (tf_estimator.tpu.TPUEstimatorSpec
                 if use_tpu else tf_estimator.EstimatorSpec)
     return spec_cls(mode=mode, loss=loss, train_op=train_op)
    def _model_fn(features, labels, params, mode=None):
        """Return the `tf.estimator.EstimatorSpec` for the given mode.

        A loss is built in every mode except PREDICT; a train op is built
        only in TRAIN mode. No eval metric ops or evaluation hooks are
        attached.
        """
        num_output_classes = len(label_vocab)
        predictions, predictions_for_loss = _make_prediction_ops(
            features=features,
            hparams=params,
            mode=mode,
            num_output_classes=num_output_classes)

        loss = None
        train_op = None
        if mode != tf_estimator.ModeKeys.PREDICT:
            loss = _make_loss(predictions_for_loss=predictions_for_loss,
                              labels=labels,
                              num_output_classes=num_output_classes)
            if mode == tf_estimator.ModeKeys.TRAIN:
                train_op = _make_train_op(loss=loss, hparams=params)

        return tf_estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=None,
            evaluation_hooks=[],
        )
示例#11
0
  def _gpu_estimator_spec_eval(self, features, logits, labels, loss,
                               losses_dict):
    """Construct the EVAL-mode `EstimatorSpec` (non-TPU path).

    Metrics whose name contains "rouge" or "bleu" are left out. `labels` and
    `losses_dict` are accepted but not used here; the metric targets come
    from `features["targets"]`.

    Raises:
      NotImplementedError: if `self.hparams` has no `problem` attribute.
    """
    hparams = self.hparams

    if not hasattr(hparams, "problem"):
      raise NotImplementedError(
          "hparams is missing attribute `problem`. NasSeq2Seq must "
          "be used with a problem.")

    # TPU is not supported.
    metric_fns = metrics.create_evaluation_metrics([hparams.problem], hparams)
    eval_metric_ops = {
        name: fn(logits, features, features["targets"])
        for name, fn in six.iteritems(metric_fns)
        if "rouge" not in name and "bleu" not in name
    }

    return tf_estimator.EstimatorSpec(
        tf_estimator.ModeKeys.EVAL,
        predictions={"predictions": logits},
        eval_metric_ops=eval_metric_ops,
        loss=loss)
示例#12
0
    def model_fn(features, labels, mode, params):
        """PREDICT-only `model_fn`.

        Args:
          features: Mapping with `guid`, `input_ids`, `input_mask` and
            `segment_ids` tensors.
          labels: Ignored.
          mode: Must be `tf_estimator.ModeKeys.PREDICT`.
          params: Ignored.

        Returns:
          An `EstimatorSpec` whose predictions hold the argmax class
          (`"predictions"`) and the example `"guid"`.

        Raises:
          ValueError: if `mode` is not PREDICT.
        """
        del labels, params

        if mode != tf_estimator.ModeKeys.PREDICT:
            raise ValueError("Only PREDICT mode is supported: %s" % (mode))

        tf.logging.info("*** Features *** %s %s" % (type(features), features))
        for feature_name in sorted(features):
            tf.logging.info(
                "  name = %s, shape = %s" %
                (feature_name, features[feature_name].shape))

        guid = features["guid"]

        logits = create_model(
            bert_config=bert_config,
            is_training=False,
            fewshot_num_examples_per_class=fewshot_num_examples_per_class,
            input_ids=features["input_ids"],
            input_mask=features["input_mask"],
            segment_ids=features["segment_ids"],
            use_one_hot_embeddings=use_one_hot_embeddings,
            tokenizer=tokenizer,
            class_examples_combiner=class_examples_combiner)

        predicted_classes = tf.argmax(logits, axis=-1, output_type=tf.int32)
        return tf_estimator.EstimatorSpec(
            mode=mode,
            predictions={
                "predictions": predicted_classes,
                "guid": guid,
            })
示例#13
0
def model_fn(features, labels, mode, params):
    """Embed queries and blocks with the ICT module and score them pairwise.

    Each query is scored against every block by a dot product of their
    projected embeddings; the label of a query is the position of its own
    block, so the incoming `labels` argument is discarded.

    Args:
      features: Mapping with `block_ids`, `block_mask`, `block_segment_ids`,
        `query_ids`, `query_mask` and `mask_query` tensors.
      labels: Ignored; labels are rebuilt from batch positions.
      mode: A `tf_estimator.ModeKeys` value, forwarded to
        `create_ict_module` and the returned spec.
      params: Mapping with `learning_rate` and `num_train_steps`, plus an
        optional `use_tpu` flag (treated as False when absent).

    Returns:
      A `TPUEstimatorSpec` when `use_tpu` is set, otherwise an
      `EstimatorSpec` that additionally carries the argmax predictions.
    """
    del labels

    # Resolve the optional flag once so that every use agrees. Previously
    # only the optimizer call guarded against a missing key, while the two
    # branches below indexed params["use_tpu"] directly.
    use_tpu = params["use_tpu"] if "use_tpu" in params else False

    # [local_batch_size, block_seq_len]
    block_ids = features["block_ids"]
    block_mask = features["block_mask"]
    block_segment_ids = features["block_segment_ids"]

    # [local_batch_size, query_seq_len]
    query_ids = features["query_ids"]
    query_mask = features["query_mask"]

    local_batch_size = tensor_utils.shape(block_ids, 0)
    tf.logging.info("Model batch size: %d", local_batch_size)

    ict_module = create_ict_module(params, mode)

    query_emb = ict_module(inputs=dict(input_ids=query_ids,
                                       input_mask=query_mask,
                                       segment_ids=tf.zeros_like(query_ids)),
                           signature="projected")
    block_emb = ict_module(inputs=dict(input_ids=block_ids,
                                       input_mask=block_mask,
                                       segment_ids=block_segment_ids),
                           signature="projected")

    if use_tpu:
        # [global_batch_size, hidden_size]
        block_emb = tpu_utils.cross_shard_concat(block_emb)

        # [global_batch_size, local_batch_size]
        labels = tpu_utils.cross_shard_pad(tf.eye(local_batch_size))

        # [local_batch_size]
        labels = tf.argmax(labels, 0)
    else:
        # [local_batch_size]
        labels = tf.range(local_batch_size)

    tf.logging.info("Global batch size: %s", tensor_utils.shape(block_emb, 0))

    # [batch_size, global_batch_size]
    logits = tf.matmul(query_emb, block_emb, transpose_b=True)

    # []
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    # Warmup is a tenth of the training steps, clamped to [100, 10000].
    train_op = optimization.create_optimizer(
        loss=loss,
        init_lr=params["learning_rate"],
        num_train_steps=params["num_train_steps"],
        num_warmup_steps=min(10000,
                             max(100, int(params["num_train_steps"] / 10))),
        use_tpu=use_tpu)

    predictions = tf.argmax(logits, -1)

    metric_args = [
        query_mask, block_mask, labels, predictions, features["mask_query"]
    ]

    def metric_fn(query_mask, block_mask, labels, predictions, mask_query):
        """Accuracy split by `mask_query`, plus means of the three masks."""
        masked_accuracy = tf.metrics.accuracy(labels=labels,
                                              predictions=predictions,
                                              weights=mask_query)
        unmasked_accuracy = tf.metrics.accuracy(
            labels=labels,
            predictions=predictions,
            weights=tf.logical_not(mask_query))
        return dict(query_non_padding=tf.metrics.mean(query_mask),
                    block_non_padding=tf.metrics.mean(block_mask),
                    actual_mask_ratio=tf.metrics.mean(mask_query),
                    masked_accuracy=masked_accuracy,
                    unmasked_accuracy=unmasked_accuracy)

    if use_tpu:
        return tf_estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                 loss=loss,
                                                 train_op=train_op,
                                                 eval_metrics=(metric_fn,
                                                               metric_args))
    else:
        return tf_estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=metric_fn(*metric_args),
            predictions=predictions)
示例#14
0
    def estimator_spec_train(self, loss, num_async_replicas=1, use_tpu=False):
        """Constructs `tf.estimator.EstimatorSpec` for TRAIN (training) mode.

        Args:
          loss: Loss tensor to optimize and to store on the spec.
          num_async_replicas: Forwarded to `self.optimize`.
          use_tpu: If true, return a `TPUEstimatorSpec` with an optional
            scaffold function and host call; otherwise return an
            `EstimatorSpec` and run checkpoint initialization directly.

        Returns:
          The TRAIN-mode estimator spec. When a pruning technique is
          configured and masks are not loaded from a checkpoint, the train op
          is the conditional mask update, which runs after the optimizer step.
        """
        train_op = self.optimize(loss,
                                 num_async_replicas=num_async_replicas,
                                 use_tpu=use_tpu)

        # `.get` is called without a default, so a missing entry may come
        # back as None; guard before the substring test instead of raising
        # TypeError on `"pruning" in None`.
        sparsity_technique = self._hparams.get("sparsity_technique")
        if sparsity_technique and "pruning" in sparsity_technique:
            if not self._hparams.load_masks_from:
                # If we are loading trained masks, don't add the mask update
                # step to the training process and keep the masks static
                with tf.control_dependencies([train_op]):
                    mp_hparams = pruning_hparams(
                        self._hparams, use_tpu,
                        sparsity_technique == "random_pruning")
                    p = magnitude_pruning.Pruning(
                        mp_hparams, global_step=tf.train.get_global_step())
                    mask_update_op = p.conditional_mask_update_op()
                    train_op = mask_update_op
            check_global_sparsity()

        if use_tpu:
            # On TPU the checkpoint initialization is deferred into a
            # scaffold function handed to the spec.
            if self._hparams.warm_start_from:

                def scaffold_fn():
                    self.initialize_from_ckpt(self._hparams.warm_start_from)
                    return tf.train.Scaffold()
            elif self._hparams.load_masks_from and self._hparams.load_weights_from:

                def scaffold_fn():
                    self.initialize_masks_from_ckpt(
                        self._hparams.load_masks_from)
                    self.initialize_non_masks_from_ckpt(
                        self._hparams.load_weights_from)
                    return tf.train.Scaffold()
            elif self._hparams.load_masks_from:

                def scaffold_fn():
                    self.initialize_masks_from_ckpt(
                        self._hparams.load_masks_from)
                    return tf.train.Scaffold()
            else:
                scaffold_fn = None

            # Note: important to call this before remove_summaries()
            if self.hparams.tpu_enable_host_call:
                host_call = t2t_model.create_host_call(self.hparams.model_dir)
            else:
                host_call = None

            t2t_model.remove_summaries()

            return contrib_tpu.TPUEstimatorSpec(tf_estimator.ModeKeys.TRAIN,
                                                loss=loss,
                                                train_op=train_op,
                                                host_call=host_call,
                                                scaffold_fn=scaffold_fn)
        else:
            # NOTE(review): unlike the TPU branch, `load_weights_from` is not
            # consulted here; confirm whether that is intentional.
            if self._hparams.warm_start_from:
                self.initialize_from_ckpt(self._hparams.warm_start_from)
            elif self._hparams.load_masks_from:
                self.initialize_masks_from_ckpt(self._hparams.load_masks_from)

            return tf_estimator.EstimatorSpec(tf_estimator.ModeKeys.TRAIN,
                                              loss=loss,
                                              train_op=train_op)
示例#15
0
    def _model_fn(features, labels, mode, params, config):
        """`Estimator` `model_fn` wrapping a cloned Keras ranker.

        Args:
          features: Mapping of input tensors. When a weights feature name is
            configured and the mode is not PREDICT, that entry is popped from
            `features` and used as the sample weight.
          labels: Labels passed to the loss and the metrics.
          mode: A `tf_estimator.ModeKeys` value.
          params: Unused.
          config: Unused.

        Returns:
          An `EstimatorSpec` with logits as predictions and three export
          outputs; outside PREDICT it also carries loss, metrics and (in
          TRAIN) the train op.

        Raises:
          ValueError: if the configured weights feature is absent from
            `features`, or `serving_default` is neither "regress" nor
            "predict".
        """
        del config, params

        # In Estimator, all sub-graphs need to be constructed inside the
        # model_fn, so the ranker, loss, metrics and optimizer are all cloned
        # in here.
        ranker = tf.keras.models.clone_model(model, clone_function=_clone_fn)
        is_training = mode == tf_estimator.ModeKeys.TRAIN

        sample_weights = None
        if weights_feature_name and mode != tf_estimator.ModeKeys.PREDICT:
            if weights_feature_name not in features:
                raise ValueError(
                    "weights_feature '{0}' can not be found in 'features'.".
                    format(weights_feature_name))
            sample_weights = utils.reshape_to_2d(
                features.pop(weights_feature_name))

        logits = ranker(features, training=is_training)

        if serving_default not in ("regress", "predict"):
            raise ValueError(
                "serving_default should be 'regress' or 'predict', "
                "but got {}".format(serving_default))

        default_export_output = (
            tf_estimator.export.RegressionOutput(logits)
            if serving_default == "regress" else
            tf_estimator.export.PredictOutput(logits))
        export_outputs = {
            _DEFAULT_SERVING_KEY: default_export_output,
            _REGRESS_SERVING_KEY: tf_estimator.export.RegressionOutput(logits),
            _PREDICT_SERVING_KEY: tf_estimator.export.PredictOutput(logits)
        }

        if mode == tf_estimator.ModeKeys.PREDICT:
            return tf_estimator.EstimatorSpec(mode=mode,
                                              predictions=logits,
                                              export_outputs=export_outputs)

        loss_fn = _clone_fn(model.loss)
        total_loss = loss_fn(labels, logits, sample_weight=sample_weights)

        keras_metrics = [_clone_fn(metric) for metric in model.metrics]
        # Adding default metrics here as model.metrics does not contain
        # custom metrics.
        keras_metrics.extend(metrics.default_keras_metrics())

        eval_metric_ops = {}
        for keras_metric in keras_metrics:
            keras_metric.update_state(labels,
                                      logits,
                                      sample_weight=sample_weights)
            eval_metric_ops[keras_metric.name] = keras_metric

        train_op = None
        if is_training:
            optimizer = _clone_fn(model.optimizer)
            optimizer.iterations = (
                tf.compat.v1.train.get_or_create_global_step())
            # Collect both the unconditional updates (the None part) and the
            # input-conditional updates (the features part). These exist for
            # layers like BatchNormalization, which have separate update and
            # minimize ops.
            unconditional_updates = ranker.get_updates_for(None)
            conditional_updates = ranker.get_updates_for(features)
            update_ops = unconditional_updates + conditional_updates
            minimize_op = optimizer.get_updates(
                loss=total_loss, params=ranker.trainable_variables)[0]
            train_op = tf.group(minimize_op, *update_ops)

        return tf_estimator.EstimatorSpec(mode=mode,
                                          predictions=logits,
                                          loss=total_loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops,
                                          export_outputs=export_outputs)
示例#16
0
def model_fn(features, labels, mode, params):
  """Build the Mesh TensorFlow MNIST graph, lower it, and return the spec.

  Args:
    features: Input features handed to `mnist_model`.
    labels: Labels handed to `mnist_model` and to the accuracy metrics.
    mode: A `tf_estimator.ModeKeys` value.
    params: Only logged here.

  Returns:
    An `EstimatorSpec` for TRAIN, PREDICT or EVAL. Falls through (returning
    None) if `mode` is none of those.
  """
  tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                  (features, labels, mode, params))
  mode_keys = tf_estimator.ModeKeys
  is_training = mode == mode_keys.TRAIN
  is_predicting = mode == mode_keys.PREDICT

  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  logits, loss = mnist_model(features, labels, mesh)

  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  # One empty device string per mesh slot.
  mesh_devices = [""] * mesh_shape.size
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, layout_rules, mesh_devices)

  if is_training:
    var_grads = mtf.gradients(
        [loss], [variable.outputs[0] for variable in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer()
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  restore_hook = mtf.MtfRestoreHook(lowering)

  tf_logits = lowering.export_to_tf_tensor(logits)
  if not is_predicting:
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf.summary.scalar("loss", tf_loss)

  if is_training:
    # Lowered optimizer updates plus the global-step increment form the
    # train op.
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_hook = tf.train.CheckpointSaverHook(
        FLAGS.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[mtf.MtfCheckpointSaverListener(lowering)])

    train_accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(tf_logits, axis=1))

    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(tf_loss, "cross_entropy")
    tf.identity(train_accuracy[1], name="train_accuracy")

    # Save accuracy scalar to Tensorboard output.
    tf.summary.scalar("train_accuracy", train_accuracy[1])

    # restore_hook must come before saver_hook
    return tf_estimator.EstimatorSpec(
        mode_keys.TRAIN,
        loss=tf_loss,
        train_op=train_op,
        training_chief_hooks=[restore_hook, saver_hook])

  if is_predicting:
    predictions = {
        "classes": tf.argmax(tf_logits, axis=1),
        "probabilities": tf.nn.softmax(tf_logits),
    }
    return tf_estimator.EstimatorSpec(
        mode=mode_keys.PREDICT,
        predictions=predictions,
        prediction_hooks=[restore_hook],
        export_outputs={
            "classify": tf_estimator.export.PredictOutput(predictions)
        })

  if mode == mode_keys.EVAL:
    eval_accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(tf_logits, axis=1))
    return tf_estimator.EstimatorSpec(
        mode=mode_keys.EVAL,
        loss=tf_loss,
        evaluation_hooks=[restore_hook],
        eval_metric_ops={"accuracy": eval_accuracy})
示例#17
0
def model_fn(features, labels, mode, params, config):
    """Builds the acoustic model.

    Four heads are built under their own variable scopes ('onsets',
    'offsets', 'velocity' and 'frame'). The frame head consumes the
    concatenated onset, activation and offset probabilities. Losses, image
    summaries and the train op are only created in TRAIN mode.

    Args:
        features: Object exposing `length`, `spec` and `sequence_id`.
        labels: Object exposing `onsets`, `offsets`, `velocities`, `labels`,
            `label_weights` and `note_sequence`. It is dereferenced in every
            mode (the metrics and the predictions dict read from it), so it
            must not be None.
        mode: A `tf_estimator.ModeKeys` value.
        params: Hyperparameter object; referred to as `hparams` below.
        config: Unused; deleted immediately.

    Returns:
        A `tf_estimator.EstimatorSpec` carrying the predictions dict and the
        eval metric ops. `loss` and `train_op` are None unless `mode` is
        TRAIN.

    Raises:
        ValueError: If `hparams.stop_activation_gradient` is set while
            `hparams.activation_loss` is not.
    """
    del config
    hparams = params

    length = features.length
    spec = features.spec

    is_training = mode == tf_estimator.ModeKeys.TRAIN

    if is_training:
        onset_labels = labels.onsets
        offset_labels = labels.offsets
        velocity_labels = labels.velocities
        frame_labels = labels.labels
        frame_label_weights = labels.label_weights

    if hparams.stop_activation_gradient and not hparams.activation_loss:
        raise ValueError(
            'If stop_activation_gradient is true, activation_loss must be true.'
        )

    # Per-head, un-reduced losses keyed by head name. Only filled in TRAIN
    # mode; reused below for metrics and scalar summaries.
    losses = {}
    with slim.arg_scope([slim.batch_norm, slim.dropout],
                        is_training=is_training):
        with tf.variable_scope('onsets'):
            onset_outputs = acoustic_model(spec,
                                           hparams,
                                           lstm_units=hparams.onset_lstm_units,
                                           lengths=length)
            onset_probs = slim.fully_connected(onset_outputs,
                                               constants.MIDI_PITCHES,
                                               activation_fn=tf.sigmoid,
                                               scope='onset_probs')

            # onset_probs_flat is used during inference.
            onset_probs_flat = flatten_maybe_padded_sequences(
                onset_probs, length)
            if is_training:
                onset_labels_flat = flatten_maybe_padded_sequences(
                    onset_labels, length)
                onset_losses = tf_utils.log_loss(onset_labels_flat,
                                                 onset_probs_flat)
                tf.losses.add_loss(tf.reduce_mean(onset_losses))
                losses['onset'] = onset_losses
        with tf.variable_scope('offsets'):
            offset_outputs = acoustic_model(
                spec,
                hparams,
                lstm_units=hparams.offset_lstm_units,
                lengths=length)
            offset_probs = slim.fully_connected(offset_outputs,
                                                constants.MIDI_PITCHES,
                                                activation_fn=tf.sigmoid,
                                                scope='offset_probs')

            # offset_probs_flat is used during inference.
            offset_probs_flat = flatten_maybe_padded_sequences(
                offset_probs, length)
            if is_training:
                offset_labels_flat = flatten_maybe_padded_sequences(
                    offset_labels, length)
                offset_losses = tf_utils.log_loss(offset_labels_flat,
                                                  offset_probs_flat)
                tf.losses.add_loss(tf.reduce_mean(offset_losses))
                losses['offset'] = offset_losses
        with tf.variable_scope('velocity'):
            velocity_outputs = acoustic_model(
                spec,
                hparams,
                lstm_units=hparams.velocity_lstm_units,
                lengths=length)
            # Linear output (no activation): velocities are regressed.
            velocity_values = slim.fully_connected(velocity_outputs,
                                                   constants.MIDI_PITCHES,
                                                   activation_fn=None,
                                                   scope='onset_velocities')

            velocity_values_flat = flatten_maybe_padded_sequences(
                velocity_values, length)
            if is_training:
                velocity_labels_flat = flatten_maybe_padded_sequences(
                    velocity_labels, length)
                # Squared velocity error, weighted by the onset labels and
                # summed over axis 1.
                velocity_loss = tf.reduce_sum(
                    onset_labels_flat *
                    tf.square(velocity_labels_flat - velocity_values_flat),
                    axis=1)
                tf.losses.add_loss(tf.reduce_mean(velocity_loss))
                losses['velocity'] = velocity_loss

        with tf.variable_scope('frame'):
            if not hparams.share_conv_features:
                # TODO(eriche): this is broken when hparams.frame_lstm_units > 0
                activation_outputs = acoustic_model(
                    spec,
                    hparams,
                    lstm_units=hparams.frame_lstm_units,
                    lengths=length)
                activation_probs = slim.fully_connected(
                    activation_outputs,
                    constants.MIDI_PITCHES,
                    activation_fn=tf.sigmoid,
                    scope='activation_probs')
            else:
                # Reuse the onset stack's outputs instead of building a
                # separate acoustic model for the activations.
                activation_probs = slim.fully_connected(
                    onset_outputs,
                    constants.MIDI_PITCHES,
                    activation_fn=tf.sigmoid,
                    scope='activation_probs')

            # Inputs to the combined frame predictor, in the order
            # [onset, activation, offset]; each may be detached from the
            # gradient according to its hparams flag.
            probs = []
            if hparams.stop_onset_gradient:
                probs.append(tf.stop_gradient(onset_probs))
            else:
                probs.append(onset_probs)

            if hparams.stop_activation_gradient:
                probs.append(tf.stop_gradient(activation_probs))
            else:
                probs.append(activation_probs)

            if hparams.stop_offset_gradient:
                probs.append(tf.stop_gradient(offset_probs))
            else:
                probs.append(offset_probs)

            combined_probs = tf.concat(probs, 2)

            if hparams.combined_lstm_units > 0:
                outputs = lstm_layer(
                    combined_probs,
                    hparams.combined_lstm_units,
                    lengths=length if hparams.use_lengths else None,
                    stack_size=hparams.combined_rnn_stack_size,
                    use_cudnn=hparams.use_cudnn,
                    bidirectional=hparams.bidirectional)
            else:
                outputs = combined_probs

            frame_probs = slim.fully_connected(outputs,
                                               constants.MIDI_PITCHES,
                                               activation_fn=tf.sigmoid,
                                               scope='frame_probs')

        frame_probs_flat = flatten_maybe_padded_sequences(frame_probs, length)

        if is_training:
            frame_labels_flat = flatten_maybe_padded_sequences(
                frame_labels, length)
            frame_label_weights_flat = flatten_maybe_padded_sequences(
                frame_label_weights, length)
            if hparams.weight_frame_and_activation_loss:
                frame_loss_weights = frame_label_weights_flat
            else:
                frame_loss_weights = None
            frame_losses = tf_utils.log_loss(frame_labels_flat,
                                             frame_probs_flat,
                                             weights=frame_loss_weights)
            tf.losses.add_loss(tf.reduce_mean(frame_losses))
            losses['frame'] = frame_losses

            if hparams.activation_loss:
                if hparams.weight_frame_and_activation_loss:
                    # Use the flattened weights so they line up with the
                    # flattened labels and probabilities passed to log_loss,
                    # exactly as the frame loss above does.
                    activation_loss_weights = frame_label_weights_flat
                else:
                    activation_loss_weights = None
                activation_losses = tf_utils.log_loss(
                    frame_labels_flat,
                    flatten_maybe_padded_sequences(activation_probs, length),
                    weights=activation_loss_weights)
                tf.losses.add_loss(tf.reduce_mean(activation_losses))
                losses['activation'] = activation_losses

    # Threshold the flattened probabilities into boolean predictions.
    frame_predictions = frame_probs_flat > hparams.predict_frame_threshold
    onset_predictions = onset_probs_flat > hparams.predict_onset_threshold
    offset_predictions = offset_probs_flat > hparams.predict_offset_threshold

    # Add a leading axis of size 1 to the flattened outputs.
    frame_predictions = tf.expand_dims(frame_predictions, axis=0)
    onset_predictions = tf.expand_dims(onset_predictions, axis=0)
    offset_predictions = tf.expand_dims(offset_predictions, axis=0)
    velocity_values = tf.expand_dims(velocity_values_flat, axis=0)

    metrics_values = metrics.define_metrics(
        frame_probs=frame_probs,
        onset_probs=onset_probs,
        frame_predictions=frame_predictions,
        onset_predictions=onset_predictions,
        offset_predictions=offset_predictions,
        velocity_values=velocity_values,
        length=features.length,
        sequence_label=labels.note_sequence,
        frame_labels=labels.labels,
        sequence_id=features.sequence_id,
        hparams=hparams)

    # Expose the per-head losses alongside the other metrics.
    for label, loss_collection in losses.items():
        loss_label = 'losses/' + label
        metrics_values[loss_label] = loss_collection

    def predict_sequence():
        """Convert frame predictions into a sequence (TF)."""
        def _predict(frame_probs, onset_probs, frame_predictions,
                     onset_predictions, offset_predictions, velocity_values):
            """Convert frame predictions into a sequence (Python)."""
            sequence = infer_util.predict_sequence(
                frame_probs=frame_probs,
                onset_probs=onset_probs,
                frame_predictions=frame_predictions,
                onset_predictions=onset_predictions,
                offset_predictions=offset_predictions,
                velocity_values=velocity_values,
                hparams=hparams,
                min_pitch=constants.MIN_MIDI_PITCH)
            return sequence.SerializeToString()

        # Only element 0 along the leading axis of each tensor is handed to
        # the Python function.
        sequence = tf.py_func(_predict,
                              inp=[
                                  frame_probs[0],
                                  onset_probs[0],
                                  frame_predictions[0],
                                  onset_predictions[0],
                                  offset_predictions[0],
                                  velocity_values[0],
                              ],
                              Tout=tf.string,
                              stateful=False)
        # py_func loses static shape information; the result is a scalar
        # serialized string.
        sequence.set_shape([])
        return tf.expand_dims(sequence, axis=0)

    predictions = {
        'frame_probs': frame_probs,
        'onset_probs': onset_probs,
        'frame_predictions': frame_predictions,
        'onset_predictions': onset_predictions,
        'offset_predictions': offset_predictions,
        'velocity_values': velocity_values,
        'sequence_predictions': predict_sequence(),
        # Include some features and labels in output because Estimator 'predict'
        # API does not give access to them.
        'sequence_ids': features.sequence_id,
        'sequence_labels': labels.note_sequence,
        'frame_labels': labels.labels,
        'onset_labels': labels.onsets,
    }
    for k, v in metrics_values.items():
        predictions[k] = tf.stack(v)

    metric_ops = {k: tf.metrics.mean(v) for k, v in metrics_values.items()}

    train_op = None
    loss = None
    if is_training:
        # Creates a pianoroll labels in red and probs in green [minibatch, 88]
        # (channel order below: labels, probabilities, zeros).
        images = {}
        onset_pianorolls = tf.concat(
            [
                onset_labels[:, :, :, tf.newaxis],
                onset_probs[:, :, :, tf.newaxis],
                tf.zeros(tf.shape(onset_labels))[:, :, :, tf.newaxis],
            ],
            axis=3)
        images['OnsetPianorolls'] = onset_pianorolls
        offset_pianorolls = tf.concat(
            [
                offset_labels[:, :, :, tf.newaxis],
                offset_probs[:, :, :, tf.newaxis],
                tf.zeros(tf.shape(offset_labels))[:, :, :, tf.newaxis],
            ],
            axis=3)
        images['OffsetPianorolls'] = offset_pianorolls
        activation_pianorolls = tf.concat(
            [
                frame_labels[:, :, :, tf.newaxis],
                frame_probs[:, :, :, tf.newaxis],
                tf.zeros(tf.shape(frame_labels))[:, :, :, tf.newaxis],
            ],
            axis=3)
        images['ActivationPianorolls'] = activation_pianorolls
        for name, image in images.items():
            tf.summary.image(name, image)

        # Total of everything registered through tf.losses.add_loss above.
        loss = tf.losses.get_total_loss()
        tf.summary.scalar('loss', loss)
        for label, loss_collection in losses.items():
            loss_label = 'losses/' + label
            tf.summary.scalar(loss_label, tf.reduce_mean(loss_collection))

        train_op = slim.optimize_loss(
            name='training',
            loss=loss,
            global_step=tf.train.get_or_create_global_step(),
            learning_rate=hparams.learning_rate,
            learning_rate_decay_fn=functools.partial(
                tf.train.exponential_decay,
                decay_steps=hparams.decay_steps,
                decay_rate=hparams.decay_rate,
                staircase=True),
            clip_gradients=hparams.clip_norm,
            optimizer='Adam')

    return tf_estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=metric_ops)
示例#18
0
def model_fn(features, labels, mode, params):
    """Estimator model function wiring a retriever into a reader.

    Blocks are retrieved on the CPU, the leading `reader_beam_size` of them
    are passed to the reader, and predictions are derived from the reader
    outputs. Outside PREDICT mode the loss, the eval metrics and a train op
    are built as well.

    Args:
        features: Input features forwarded to `retrieve` and `read`.
        labels: Answer labels; replaced by a constant holding one empty
            string when None.
        mode: A `tf_estimator.ModeKeys` value.
        params: Dict read for "reader_beam_size", "retriever_beam_size",
            "learning_rate" and "num_train_steps"; also forwarded to the
            helper functions.

    Returns:
        A `tf_estimator.EstimatorSpec`. In PREDICT mode only `predictions`
        is populated.
    """
    predicting = mode == tf_estimator.ModeKeys.PREDICT

    if labels is None:
        labels = tf.constant([""])

    # When predicting, the retriever fetches exactly as many blocks as the
    # reader consumes.
    num_read = params["reader_beam_size"]
    num_retrieved = num_read if predicting else params["retriever_beam_size"]
    assert num_read <= num_retrieved

    with tf.device("/cpu:0"):
        retrieved = retrieve(features=features,
                             retriever_beam_size=num_retrieved,
                             mode=mode,
                             params=params)

    with tf.variable_scope("reader"):
        # The reader only sees the leading `num_read` retrieved entries.
        top_logits = retrieved.logits[:num_read]
        top_blocks = retrieved.blocks[:num_read]
        reader_outputs = read(features=features,
                              retriever_logits=top_logits,
                              blocks=top_blocks,
                              mode=mode,
                              params=params,
                              labels=labels)

    predictions = get_predictions(reader_outputs, params)

    if predicting:
        return tf_estimator.EstimatorSpec(mode=mode,
                                          loss=None,
                                          train_op=None,
                                          predictions=predictions,
                                          eval_metric_ops=None)

    # [retriever_beam_size]
    retriever_correct = orqa_ops.has_answer(blocks=retrieved.blocks,
                                            answers=labels)

    # [reader_beam_size, num_candidates]
    reader_correct = compute_correct_candidates(
        candidate_starts=reader_outputs.candidate_starts,
        candidate_ends=reader_outputs.candidate_ends,
        gold_starts=reader_outputs.gold_starts,
        gold_ends=reader_outputs.gold_ends)

    eval_metric_ops = compute_eval_metrics(
        labels=labels,
        predictions=predictions,
        retriever_correct=retriever_correct,
        reader_correct=reader_correct)

    # []
    loss = compute_loss(retriever_logits=retrieved.logits,
                        retriever_correct=retriever_correct,
                        reader_logits=reader_outputs.logits,
                        reader_correct=reader_correct)

    # Warm up for a tenth of training, clamped to [100, 10000] steps.
    initial_lr = params["learning_rate"]
    total_steps = params["num_train_steps"]
    warmup_steps = min(10000, max(100, int(total_steps / 10)))
    train_op = optimization.create_optimizer(loss=loss,
                                             init_lr=initial_lr,
                                             num_train_steps=total_steps,
                                             num_warmup_steps=warmup_steps,
                                             use_tpu=False)

    return tf_estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      predictions=predictions,
                                      eval_metric_ops=eval_metric_ops)
示例#19
0
def resnet_model_fn(features, labels, mode, params):
  """The model_fn for ResNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images. If transpose_input is enabled, it is
      transposed to device layout and reshaped to 1D tensor. A dict is also
      accepted, in which case its 'feature' entry is used.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
      `params['batch_size']` is always provided and should be used as the
      effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model in TRAIN and EVAL mode; a plain
    `EstimatorSpec` carrying predictions and export outputs in PREDICT mode.

  Raises:
    ValueError: If `params['dropblock_groups']` names a group outside 1..4,
      if `params['precision']` is neither 'bfloat16' nor 'float32', if
      `params['enable_lars']` is set, or if `FLAGS.optimizer` is not one of
      'momentum', 'adam' or 'eigt'.
  """
  if isinstance(features, dict):
    features = features['feature']

  # In most cases, the default data format NCHW instead of NHWC should be
  # used for a significant performance boost on GPU/TPU. NHWC should be used
  # only if the network needs to be run on CPU since the pooling operations
  # are only supported on NHWC.
  if params['data_format'] == 'channels_first':
    assert not params['transpose_input']  # channels_first only for GPU
    features = tf.transpose(features, [0, 3, 1, 2])

  if params['transpose_input'] and mode != tf_estimator.ModeKeys.PREDICT:
    # Recover the square image side from the flattened feature length and
    # the number of labels.
    # NOTE(review): `/` on the integer shape tensors presumably yields a
    # float that is then used as a reshape dimension -- confirm this path
    # works as intended.
    image_size = tf.sqrt(tf.shape(features)[0] / (3 * tf.shape(labels)[0]))
    features = tf.reshape(features, [image_size, image_size, 3, -1])
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

  # Normalize the image to zero mean and unit variance.
  features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
  features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

  # DropBlock keep_prob for the 4 block groups of ResNet architecture.
  # None means applying no DropBlock at the corresponding block group.
  dropblock_keep_probs = [None] * 4
  if params['dropblock_groups']:
    # Scheduled keep_prob for DropBlock: goes linearly from 1 at step 0 to
    # params['dropblock_keep_prob'] at step params['train_steps'].
    train_steps = tf.cast(params['train_steps'], tf.float32)
    current_step = tf.cast(tf.train.get_global_step(), tf.float32)
    current_ratio = current_step / train_steps
    dropblock_keep_prob = (1 - current_ratio *
                           (1 - params['dropblock_keep_prob']))

    # Computes DropBlock keep_prob for different block groups of ResNet.
    dropblock_groups = [int(x) for x in params['dropblock_groups'].split(',')]
    for block_group in dropblock_groups:
      if block_group < 1 or block_group > 4:
        raise ValueError(
            'dropblock_groups should be a comma separated list of integers '
            'between 1 and 4 (dropblcok_groups: {}).'.format(
                params['dropblock_groups']))
      # The drop rate is divided by 4 for every group below group 4.
      dropblock_keep_probs[block_group - 1] = 1 - (
          (1 - dropblock_keep_prob) / 4.0**(4 - block_group))

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    network = resnet_model.resnet_v1(
        resnet_depth=params['resnet_depth'],
        num_classes=params['num_label_classes'],
        dropblock_size=params['dropblock_size'],
        dropblock_keep_probs=dropblock_keep_probs,
        data_format=params['data_format'])
    return network(
        inputs=features, is_training=(mode == tf_estimator.ModeKeys.TRAIN))

  if params['precision'] == 'bfloat16':
    with contrib_tpu.bfloat16_scope():
      logits = build_network()
    # Cast back so the loss and metrics below operate on float32.
    logits = tf.cast(logits, tf.float32)
  elif params['precision'] == 'float32':
    logits = build_network()
  else:
    # Fail here with a clear message rather than later with an unbound
    # `logits`.
    raise ValueError('{} is not a supported precision'.format(
        params['precision']))

  if mode == tf_estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf_estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf_estimator.export.PredictOutput(predictions)
        })

  # If necessary, in the model_fn, use params['batch_size'] instead the batch
  # size flags (--train_batch_size or --eval_batch_size).
  batch_size = params['batch_size']  # pylint: disable=unused-variable

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  one_hot_labels = tf.one_hot(labels, params['num_label_classes'])
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=one_hot_labels,
      label_smoothing=params['label_smoothing'])

  # Add weight decay to the loss for non-batch-normalization variables.
  loss = cross_entropy + params['weight_decay'] * tf.add_n([
      tf.nn.l2_loss(v)
      for v in tf.trainable_variables()
      if 'batch_normalization' not in v.name
  ])

  host_call = None
  if mode == tf_estimator.ModeKeys.TRAIN:
    # Compute the current epoch and associated learning rate from global_step.
    global_step = tf.train.get_global_step()
    steps_per_epoch = params['num_train_images'] / params['train_batch_size']
    current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
    # LARS is a large batch optimizer. LARS enables higher accuracy at batch 16K
    # and larger batch sizes.
    if params['enable_lars']:
      # LARS is deliberately rejected in this setup; the optimizer is still
      # initialized before the error is raised.
      learning_rate = 0.0
      optimizer = lars_util.init_lars_optimizer(current_epoch, params)
      raise ValueError('LARS unexpected in the context of IGT experiments.')
    else:
      learning_rate = linear_learning_rate_schedule(params, global_step)

      if FLAGS.optimizer == 'momentum':
        tf.logging.info('Using MomentumOptimizer ({}).'.format(
            params['momentum']))
        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate,
            momentum=params['momentum'],
            use_nesterov=False)

      elif FLAGS.optimizer == 'adam':
        tf.logging.info('Using AdamOptimizer')
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

      elif FLAGS.optimizer == 'eigt':
        tf.logging.info('Using ExpIgtOptimizer {} tail: {}'.format(
            FLAGS.igt_optimizer, FLAGS.tail_fraction))
        optimizer = exp_igt_optimizer.ExpIgtOptimizer(
            learning_rate,
            tail_fraction=FLAGS.tail_fraction,
            optimizer=FLAGS.igt_optimizer)

      else:
        raise ValueError('{} is not a supported optimizer'.format(
            FLAGS.optimizer))

    if params['use_tpu']:
      # When using TPU, wrap the optimizer with CrossShardOptimizer which
      # handles synchronization details between different TPU cores. To the
      # user, this should look like regular synchronous training.
      optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)

    if not params['skip_host_call']:

      def host_call_fn(gs, loss, lr, ce):
        """Training host call.

        Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `metric_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor with shape `[batch]` for the global_step
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
        gs = gs[0]
        # Host call fns are executed params['iterations_per_loop'] times after
        # one TPU loop is finished, setting max_queue value to the same as
        # number of iterations will make the summary writer only flush the data
        # to storage once per loop.
        with summary.create_file_writer(
            get_model_dir(params),
            max_queue=params['iterations_per_loop']).as_default():
          with summary.always_record_summaries():
            summary.scalar('loss', loss[0], step=gs)
            summary.scalar('learning_rate', lr[0], step=gs)
            summary.scalar('current_epoch', ce[0], step=gs)

            return summary.all_summary_ops()

      # To log the loss, current learning rate, and epoch for Tensorboard, the
      # summary op needs to be run on the host CPU via host_call. host_call
      # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
      # dimension. These Tensors are implicitly concatenated to
      # [params['batch_size']].
      gs_t = tf.reshape(global_step, [1])
      loss_t = tf.reshape(loss, [1])
      lr_t = tf.reshape(learning_rate, [1])
      ce_t = tf.reshape(current_epoch, [1])

      host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

  else:
    train_op = None

  eval_metrics = None
  scaffold_fn = None
  if mode == tf_estimator.ModeKeys.EVAL:

    def metric_fn(labels, logits):
      """Evaluation metric function.

      Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
      predictions = tf.argmax(logits, axis=1)
      top_1_accuracy = tf.metrics.accuracy(labels, predictions)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      top_5_accuracy = tf.metrics.mean(in_top_5)

      return {
          'top_1_accuracy': top_1_accuracy,
          'top_5_accuracy': top_5_accuracy,
      }

    eval_metrics = (metric_fn, [labels, logits])

    if FLAGS.mode == 'eval_igt' and FLAGS.igt_eval_mode == 'true':
      tf.logging.info('Using true param loading saver.')

      def scaffold_fn_true_params():
        """Returns a scaffold that loads the true values into vars."""
        # Trainable variables are restored from the checkpoint entry named
        # '<var>/true_param'; every other global variable keeps its own name.
        var_mapping = {}
        trainable_vars = set(tf.trainable_variables())
        for var in tf.global_variables():
          if var in trainable_vars:
            var_mapping[var.op.name + '/true_param'] = var
          else:
            var_mapping[var.op.name] = var

        tf.logging.info('Mapping: {}'.format(var_mapping))
        saver = tf.train.Saver(var_list=var_mapping, sharded=True)
        return tf.train.Scaffold(saver=saver)

      scaffold_fn = scaffold_fn_true_params

  return contrib_tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics,
      scaffold_fn=scaffold_fn)
示例#20
0
def model_fn(features, labels, mode, params, grammar):
    """Constructs the graph that predicts the next production rule.

    Args:
        features: Dict of tensors.
        labels: Tensor of target production rules; not used in PREDICT mode.
        mode: tf.estimator.ModeKeys execution mode.
        params: HParams object containing model hyperparameters.
        grammar: arithmetic_grammar.Grammar object.

    Returns:
        A `tf_estimator.EstimatorSpec`: predictions only in PREDICT mode,
        loss and train_op in TRAIN mode, loss and eval_metric_ops otherwise.
    """
    predict_mode = mode == tf_estimator.ModeKeys.PREDICT

    if not predict_mode:
        tf.summary.text('expression_string',
                        features['expression_string'][:10])
    tf.summary.text('production_rules',
                    tf.constant(grammar.grammar_to_string()))

    # Re-expose every feature under a predictable name for easy lookup.
    with tf.variable_scope('features'):
        named_features = {}
        for key, value in six.iteritems(features):
            named_features[key] = tf.identity(value, name=key)
        features = named_features

    sequence_lengths = features['partial_sequence_length']
    rule_mask = features['next_production_rule_mask']

    symbolic_properties = core.hparams_list_value(params.symbolic_properties)
    numerical_points = core.hparams_list_value(params.numerical_points)
    embedding_layer = networks.partial_sequence_encoder(
        features=features,
        symbolic_properties=symbolic_properties,
        numerical_points=numerical_points,
        num_production_rules=grammar.num_production_rules,
        embedding_size=params.embedding_size)

    logits = networks.build_stacked_gru_model(
        embedding_layer=embedding_layer,
        partial_sequence_length=sequence_lengths,
        gru_hidden_sizes=params.gru_hidden_sizes,
        num_output_features=grammar.num_production_rules,
        bidirectional=params.bidirectional)

    predictions = {'logits': tf.identity(logits, name='predictions/logits')}
    masked_outputs = mask_logits(logits, rule_mask)
    for name, tensor in six.iteritems(masked_outputs):
        predictions[name] = tf.identity(tensor, name='predictions/%s' % name)

    masked_probabilities = predictions['masked_probabilities']
    predictions['next_production_rule'] = tf.argmax(
        masked_probabilities,
        axis=1,
        name='predictions/next_production_rule')

    if predict_mode:
        return tf_estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # NOTE(leeley): The mask cannot be applied directly on logits, because a
    # 0 logit still corresponds to a positive probability. Since
    # tf.losses.sparse_softmax_cross_entropy() only works on logits rather
    # than probabilities, the masked probabilities are converted back to
    # logits with tf.log(). Grammatically invalid production rules have
    # probability 0, so a small 1e-10 is added to avoid log(0).
    masked_log_probabilities = tf.log(masked_probabilities + 1e-10)
    loss = tf.losses.sparse_softmax_cross_entropy(labels,
                                                  masked_log_probabilities)

    # TRAIN mode: build the training op and return right away.
    if mode == tf_estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_global_step()
        decayed_learning_rate = core.learning_rate_decay(
            initial_learning_rate=params.learning_rate,
            decay_steps=params.learning_rate_decay_steps,
            decay_rate=params.learning_rate_decay_rate)
        train_op = contrib_layers.optimize_loss(
            loss=loss,
            global_step=global_step,
            learning_rate=decayed_learning_rate,
            optimizer=params.optimizer,
            summaries=contrib_layers.OPTIMIZER_SUMMARIES)
        return tf_estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op)

    # EVAL mode: aggregate metrics over all examples first.
    unmasked_probabilities = predictions['unmasked_probabilities']
    predicted_rules = predictions['next_production_rule']

    eval_metric_ops = {}
    eval_metric_ops['eval_loss'] = tf.metrics.mean(loss)
    eval_metric_ops['count'] = contrib_metrics.count(labels)
    eval_metric_ops['next_production_rule_valid_ratio'] = (
        metrics.next_production_rule_valid_ratio(
            unmasked_probabilities_batch=unmasked_probabilities,
            next_production_rule_masks=rule_mask))
    eval_metric_ops['next_production_rule_accuracy'] = (
        metrics.next_production_rule_accuracy(
            next_production_rules=labels,
            predict_next_production_rules=predicted_rules))

    # Then the same metrics restricted to each partial-sequence length.
    for target_length in range(1, params.max_length + 1):
        length_tag = 'length_%d' % target_length

        info_summary = metrics.next_production_rule_info_batch_text_summary(
            expression_strings=features['expression_string'],
            partial_sequences=features['partial_sequence'],
            partial_sequence_lengths=sequence_lengths,
            next_production_rules=labels,
            unmasked_probabilities_batch=unmasked_probabilities,
            masked_probabilities_batch=masked_probabilities,
            grammar=grammar,
            target_length=target_length)
        eval_metric_ops['next_production_rule_info/' + length_tag] = (
            info_summary)

        valid_ratio = metrics.next_production_rule_valid_ratio(
            unmasked_probabilities_batch=unmasked_probabilities,
            next_production_rule_masks=rule_mask,
            partial_sequence_lengths=sequence_lengths,
            target_length=target_length)
        eval_metric_ops['next_production_rule_valid_ratio/' + length_tag] = (
            valid_ratio)

        accuracy = metrics.next_production_rule_accuracy(
            next_production_rules=labels,
            predict_next_production_rules=predicted_rules,
            partial_sequence_lengths=sequence_lengths,
            target_length=target_length)
        eval_metric_ops['next_production_rule_accuracy/' + length_tag] = (
            accuracy)

    if params.num_expressions_per_condition > 0:
        # Placeholders whose running means are reported as generation
        # metrics; their values must be fed from outside this function.
        with tf.variable_scope('conditional_generation'):
            match_ratio = tf.placeholder(tf.float32,
                                         shape=[None],
                                         name='match_ratio')
            fail_ratio = tf.placeholder(tf.float32,
                                        shape=[None],
                                        name='fail_ratio')

        eval_metric_ops['generation_match_ratio'] = tf.metrics.mean(
            match_ratio)
        eval_metric_ops['generation_fail_ratio'] = tf.metrics.mean(fail_ratio)

    return tf_estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      eval_metric_ops=eval_metric_ops)
示例#21
0
def resnet_model_fn(features, labels, mode, params):
    """Build the EstimatorSpec for training or evaluating a ResNet-50.

    The image batch is normalized with the per-channel statistics found in
    `params`, fed through `resnet_model.resnet_50`, and scored with
    label-smoothed softmax cross entropy plus an L2 penalty on every
    trainable variable whose name does not contain 'batch_normalization'.

    Args:
      features: A batch of images, or a dict holding that batch under the key
        'feature'.
      labels: A batch of integer class ids (one-hot encoded here).
      mode: A `tf_estimator.ModeKeys` value.
      params: Dictionary of parameters. Keys read in every mode: 'mean_rgb',
        'stddev_rgb', 'train_batch_size', 'num_train_images',
        'base_learning_rate', 'num_label_classes', 'data_format',
        'output_dir' and 'weight_decay'. TRAIN additionally reads
        'lr_schedule' and 'momentum'.

    Returns:
      A `tf_estimator.EstimatorSpec` carrying the loss, the train op (TRAIN
      only, otherwise None) and top-1 / top-5 accuracy metrics (EVAL only,
      otherwise an empty dict).
    """
    if isinstance(features, dict):
        features = features['feature']

    # Per-channel normalization: subtract the mean, then divide by the stddev.
    channel_shape = [1, 1, 3]
    mean_rgb = params['mean_rgb']
    stddev_rgb = params['stddev_rgb']
    features -= tf.constant(mean_rgb,
                            shape=channel_shape,
                            dtype=features.dtype)
    features /= tf.constant(stddev_rgb,
                            shape=channel_shape,
                            dtype=features.dtype)

    train_batch_size = params['train_batch_size']
    steps_per_epoch = params['num_train_images'] / train_batch_size
    initial_learning_rate = params['base_learning_rate']
    num_label_classes = params['num_label_classes']

    is_training = mode == tf_estimator.ModeKeys.TRAIN
    network = resnet_model.resnet_50(num_classes=num_label_classes,
                                     data_format=params['data_format'])
    logits = network(inputs=features, is_training=is_training)

    output_dir = params['output_dir']
    weight_decay = params['weight_decay']

    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=tf.one_hot(labels, num_label_classes),
        label_smoothing=0.1)

    # Weight decay skips the batch-normalization variables.
    l2_terms = [
        tf.nn.l2_loss(var) for var in tf.trainable_variables()
        if 'batch_normalization' not in var.name
    ]
    loss = cross_entropy + weight_decay * tf.add_n(l2_terms)

    train_op = None
    if is_training:
        global_step = tf.train.get_global_step()

        current_epoch = tf.cast(global_step, tf.float32) / steps_per_epoch
        learning_rate = compute_lr(current_epoch, initial_learning_rate,
                                   train_batch_size, params['lr_schedule'])
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=params['momentum'],
                                               use_nesterov=True)

        # Run the collected update ops before every optimizer step.
        pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(pending_updates):
            with tf.name_scope('train'):
                train_op = optimizer.minimize(loss, global_step)

        scalars_to_log = (
            ('loss', loss),
            ('learning_rate', learning_rate),
            ('current_epoch', current_epoch),
            ('steps_per_epoch', steps_per_epoch),
            ('weight_decay', weight_decay),
        )
        with tf2.summary.create_file_writer(output_dir).as_default():
            with tf2.summary.record_if(True):
                for tag, value in scalars_to_log:
                    tf2.summary.scalar(tag, value, step=global_step)

            tf.summary.all_v2_summary_ops()

    eval_metrics = {}
    if mode == tf_estimator.ModeKeys.EVAL:
        top_1 = tf.argmax(logits, axis=1)
        eval_metrics['top_1_accuracy'] = tf.metrics.accuracy(labels, top_1)
        top_5_hits = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
        eval_metrics['top_5_accuracy'] = tf.metrics.mean(top_5_hits)

    # No host call is configured, so no extra training hooks are attached.
    return tf_estimator.EstimatorSpec(training_hooks=None,
                                      mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=eval_metrics)
示例#22
0
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None,
                           use_tpu=False):
        """Build the estimator spec for this Mesh-TensorFlow model.

        Works on a copy of `hparams`, builds the mesh implementation (SIMD
        mesh on TPU, placement mesh otherwise) and then dispatches on `mode`:
        PREDICT delegates to `model.estimator_spec_predict`, EVAL delegates to
        `model.estimator_spec_eval`, and everything after the EVAL return is
        built as a TRAIN spec.

        Args:
          hparams: Hyperparameters; copied before `use_tpu` and any decode
            overrides are set on them.
          features: Forwarded to the model (`mtf_model_fn`,
            `estimator_spec_predict`, `estimator_spec_eval`).
          labels: Forwarded to `estimator_spec_eval`.
          mode: A `tf_estimator.ModeKeys` value.
          config: Optional; when truthy and not on TPU, its `data_parallelism`
            attribute is used.
          params: Must hold the key "context" when `use_tpu` is true.
          decode_hparams: Optional; in PREDICT mode its values are merged into
            the copied hparams.
          use_tpu: Whether to build the TPU (SIMD mesh) variant.

        Returns:
          In PREDICT / EVAL mode, whatever the model's
          `estimator_spec_predict` / `estimator_spec_eval` returns. Otherwise
          a `TPUEstimatorSpec` (TPU) or `EstimatorSpec` (non-TPU), both with
          mode TRAIN.
        """
        hparams = hparams_lib.copy_hparams(hparams)
        hparams.use_tpu = use_tpu
        # merge decode_hparams into hparams if present
        if mode == tf_estimator.ModeKeys.PREDICT and decode_hparams is not None:
            for k, v in six.iteritems(decode_hparams.values()):
                # Only warn when an existing value is actually replaced; the
                # attribute is set either way.
                if hasattr(hparams, k) and getattr(hparams, k) != v:
                    tf.logging.warning(
                        "Overriding hparams.%s with %s from decode_hparams" %
                        (k, v))
                setattr(hparams, k, v)

        # Instantiate model
        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        model = cls(hparams,
                    mode,
                    data_parallelism=data_parallelism,
                    decode_hparams=decode_hparams)

        global_step = tf.train.get_global_step()

        mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(hparams.layout)
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            # One placement device per TPU host.
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            # Initial memory usage handed to the variable placer: worker 0 is
            # charged the cache estimate, every other host starts at 0.
            devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memeory_usage)
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
        else:
            var_placer = None
            # With no data parallelism, or a single ps device, every mesh
            # slot gets the empty device string.
            if data_parallelism is None or len(
                    data_parallelism.ps_devices) == 1:
                mesh_devices = [""] * mesh_shape.size
            else:
                assert len(data_parallelism.ps_devices) == mesh_shape.size
                mesh_devices = data_parallelism.ps_devices
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)
        # PREDICT mode
        if mode == tf_estimator.ModeKeys.PREDICT:
            return model.estimator_spec_predict(features, mesh, mesh_impl,
                                                use_tpu)

        logits, loss = model.mtf_model_fn(features, mesh)
        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        # `update_ops` is only bound in TRAIN mode; it is consumed after the
        # lowering below, again under a TRAIN check.
        if mode == tf_estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            lr = learning_rate.learning_rate_schedule(hparams)
            tf.summary.scalar("learning_rate", lr)
            mtf_lr = mtf.import_tf_tensor(
                mesh, tf.convert_to_tensor(lr, dtype=tf.float32),
                mtf.Shape([]))
            optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

        # The lowering is created only after the gradient / optimizer ops
        # above have been added to the mtf graph.
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        # NOTE(review): this relies on the truthiness of `logits`, while the
        # TPU check above uses `logits is not None` -- confirm the two are
        # meant to agree.
        if logits and mode != tf_estimator.ModeKeys.TRAIN:
            tf_logits = lowering.export_to_tf_tensor(logits)

        if mode == tf_estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            # The global step is incremented as part of the grouped train op.
            tf_update_ops.append(tf.assign_add(global_step, 1))
            # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                hparams.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

        # EVAL mode
        if mode == tf_estimator.ModeKeys.EVAL:
            # NOTE(review): logits are exported here unconditionally, even
            # though the non-TRAIN branch above may already have exported
            # them into `tf_logits`.
            tf_logits = lowering.export_to_tf_tensor(logits)
            return model.estimator_spec_eval(features, tf_logits, labels,
                                             tf_loss, restore_hook, use_tpu)

        # Everything below builds a TRAIN spec and uses `train_op`, which is
        # only bound when mode is TRAIN.
        if use_tpu:
            # TPU host call. Important: need to be called before remove_summaries()
            if hparams.tpu_enable_host_call:
                host_call = t2t_model.create_host_call(hparams.model_dir)
            else:
                host_call = None

            if hparams.warm_start_from:

                # On TPU the warm start is deferred into the scaffold
                # function passed to the spec.
                def scaffold_fn():
                    t2t_model.initialize_from_ckpt(
                        ckpt_dir=hparams.warm_start_from, hparams=hparams)
                    return tf.train.Scaffold()
            else:
                scaffold_fn = None

            t2t_model.remove_summaries()
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf_estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                host_call=host_call,
                training_hooks=[restore_hook, saver_hook],
                scaffold_fn=scaffold_fn)
        else:
            # Off TPU the warm start is applied directly while building the
            # graph.
            if hparams.warm_start_from:
                t2t_model.initialize_from_ckpt(
                    ckpt_dir=hparams.warm_start_from, hparams=hparams)
            return tf_estimator.EstimatorSpec(
                tf_estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_chief_hooks=[restore_hook, saver_hook])
        def model_fn(features, labels, mode):
            """Model function for the IPS-reweighted baseline classifier.

            Builds a feed-forward network over `self._feature_columns`,
            computes the example-weighted loss and returns an `EstimatorSpec`
            for TRAIN or EVAL mode.

            Args:
              features: `Tensor` or `dict` of `Tensor`.
              labels: A `dict` of `Tensor` objects. Expects entries for
                self._label_column_name, IPS_WITH_LABEL_TARGET_COLUMN_NAME
                and IPS_WITHOUT_LABEL_TARGET_COLUMN_NAME.
                IPS stands for inverse propensity score, wherein each example
                is assigned a weight inversely proportionate to its propensity
                of appearing in the training distribution. Concretely,
                ips-weight = 1/p(x), where p(x) is the probability of x in the
                training distribution.
                In "IPS_without_label", each example is given a weight as the
                inverse propensity score of its subgroup, for example
                1/p("Black Female").
                In "IPS_with_label", each example is assigned a weight as the
                inverse propensity score of its subgroup and class
                membership, for example 1/p("Black Female", "class 0").
              mode: Defines whether this is training or evaluation. See
                `ModeKeys`. PREDICT mode is not implemented.

            Returns:
              An instance of `tf.estimator.EstimatorSpec`, which encapsulates
              the `mode`, `predictions`, `loss` and either the `train_op`
              (TRAIN) or the `eval_metric_ops` (EVAL). `predictions` is a
              `dict` of `Tensor` objects representing the prediction of the
              binary classification model.

            Raises:
              ValueError: If `self._reweighting_type` is neither
                'IPS_with_label' nor 'IPS_without_label', or if `mode` is
                neither TRAIN nor EVAL.
            """

            # Instantiates a tensor with true class labels
            class_labels = labels[self._label_column_name]

            ips_example_weights_with_label = labels[
                IPS_WITH_LABEL_TARGET_COLUMN_NAME]
            ips_example_weights_without_label = labels[
                IPS_WITHOUT_LABEL_TARGET_COLUMN_NAME]

            tf.logging.info('model_fn for mode: {}'.format(mode))

            with tf.name_scope('model'):
                input_layer = tf.feature_column.input_layer(
                    features, self._feature_columns)
                layer = input_layer
                for unit in self._hidden_units:
                    layer = tf.layers.Dense(unit,
                                            activation=self._activation)(layer)
                logits = tf.layers.Dense(1)(layer)
                sigmoid_output = tf.nn.sigmoid(logits, name='sigmoid')
                class_predictions = tf.cast(tf.greater(sigmoid_output, 0.5), tf.float32)  # pylint: disable=line-too-long
                tf.summary.histogram('class_predictions', class_predictions)

            if self._reweighting_type == 'IPS_with_label':
                example_weights = ips_example_weights_with_label
            elif self._reweighting_type == 'IPS_without_label':
                example_weights = ips_example_weights_without_label
            else:
                # Fail with a clear message instead of a NameError on the
                # unbound `example_weights` below.
                raise ValueError(
                    'Unsupported reweighting_type: {!r}; expected '
                    '"IPS_with_label" or "IPS_without_label".'.format(
                        self._reweighting_type))

            # Initializes Loss Functions
            loss = self._loss(class_labels, logits, example_weights)

            # Sets up dictionaries used for computing performance metrics
            predictions = {
                (self._label_column_name, 'class_ids'):
                tf.reshape(class_predictions, [-1]),
                (self._label_column_name, 'logistic'):
                tf.reshape(sigmoid_output, [-1])
            }

            # EVAL Mode
            if mode == tf_estimator.ModeKeys.EVAL:
                class_id_kwargs = {
                    'labels': class_labels,
                    'predictions': class_predictions
                }
                logistics_kwargs = {
                    'labels': class_labels,
                    'predictions': sigmoid_output
                }
                with tf.name_scope('eval_metrics'):
                    eval_metric_ops = {
                        'accuracy':
                        tf.metrics.accuracy(**class_id_kwargs),
                        'precision':
                        tf.metrics.precision(**class_id_kwargs),
                        'recall':
                        tf.metrics.recall(**class_id_kwargs),
                        'fp':
                        tf.metrics.false_positives(**class_id_kwargs),
                        'fn':
                        tf.metrics.false_negatives(**class_id_kwargs),
                        'tp':
                        tf.metrics.true_positives(**class_id_kwargs),
                        'tn':
                        tf.metrics.true_negatives(**class_id_kwargs),
                        'fpr':
                        contrib_metrics.streaming_false_positive_rate(
                            **class_id_kwargs),  # pylint: disable=line-too-long
                        'fnr':
                        contrib_metrics.streaming_false_negative_rate(
                            **class_id_kwargs),  # pylint: disable=line-too-long
                        'auc':
                        tf.metrics.auc(curve='ROC', **logistics_kwargs),
                        'aucpr':
                        tf.metrics.auc(curve='PR', **logistics_kwargs)
                    }

                    # EstimatorSpec object for evaluation
                    return tf_estimator.EstimatorSpec(
                        mode=mode,
                        predictions=predictions,
                        loss=loss,
                        eval_metric_ops=eval_metric_ops)

            # TRAIN Mode
            if mode == tf_estimator.ModeKeys.TRAIN:
                train_op_primary = contrib_layers.optimize_loss(
                    loss=loss,
                    learning_rate=self._learning_rate,
                    global_step=contrib_framework.get_global_step(),
                    optimizer=self._optimizer)

                return tf_estimator.EstimatorSpec(
                    mode=mode,
                    predictions=predictions,
                    loss=loss,
                    train_op=train_op_primary)

            # Any other mode (e.g. PREDICT) is unsupported; say so explicitly
            # rather than failing on an unbound `estimator_spec`.
            raise ValueError(
                'model_fn only supports TRAIN and EVAL modes, got: {}'.format(
                    mode))
    def model_fn(features, labels, mode):
      """Model function for the baseline binary classifier.

      Builds a feed-forward network over `self._feature_columns`, computes the
      loss and returns an `EstimatorSpec` for TRAIN or EVAL mode.

      Args:
        features: `Tensor` or `dict` of `Tensor`.
        labels: A `dict` of `Tensor` objects. Expects to have a key/value pair
          for the key self._label_column_name.
        mode: Defines whether this is training or evaluation. See `ModeKeys`.
          PREDICT mode is not implemented.

      Returns:
        An instance of `tf.estimator.EstimatorSpec`, which encapsulates the
        `mode`, `predictions`, `loss` and either the `train_op` (TRAIN) or the
        `eval_metric_ops` (EVAL). `predictions` is a `dict` of `Tensor`
        objects representing the prediction of the binary classification
        model.

      Raises:
        ValueError: If `mode` is neither TRAIN nor EVAL.
      """

      # Instantiates a tensor with true class labels
      class_labels = labels[self._label_column_name]

      tf.logging.info('model_fn for mode: {}'.format(mode))

      with tf.name_scope('model'):
        input_layer = tf.feature_column.input_layer(features,
                                                    self._feature_columns)
        layer = input_layer
        for unit in self._hidden_units:
          layer = tf.layers.Dense(unit, activation=self._activation)(layer)
        logits = tf.layers.Dense(1)(layer)
        sigmoid_output = tf.nn.sigmoid(logits, name='sigmoid')
        class_predictions = tf.cast(tf.greater(sigmoid_output, 0.5), tf.float32)
        tf.summary.histogram('class_predictions', class_predictions)

      # Initializes Loss Functions
      loss = self._loss(class_labels, logits)
      # Sets up dictionaries used for computing performance metrics
      predictions = {
          (self._label_column_name, 'class_ids'):
              tf.reshape(class_predictions, [-1]),
          (self._label_column_name, 'logistic'):
              tf.reshape(sigmoid_output, [-1])
      }

      # EVAL Mode
      if mode == tf_estimator.ModeKeys.EVAL:
        class_id_kwargs = {
            'labels': class_labels,
            'predictions': class_predictions
        }
        logistics_kwargs = {
            'labels': class_labels,
            'predictions': sigmoid_output
        }
        with tf.name_scope('eval_metrics'):
          eval_metric_ops = {
              'accuracy': tf.metrics.accuracy(**class_id_kwargs),
              'precision': tf.metrics.precision(**class_id_kwargs),
              'recall': tf.metrics.recall(**class_id_kwargs),
              'fp': tf.metrics.false_positives(**class_id_kwargs),
              'fn': tf.metrics.false_negatives(**class_id_kwargs),
              'tp': tf.metrics.true_positives(**class_id_kwargs),
              'tn': tf.metrics.true_negatives(**class_id_kwargs),
              'fpr': contrib_metrics.streaming_false_positive_rate(**class_id_kwargs),  # pylint: disable=line-too-long
              'fnr': contrib_metrics.streaming_false_negative_rate(**class_id_kwargs),  # pylint: disable=line-too-long
              'auc': tf.metrics.auc(curve='ROC', **logistics_kwargs),
              'aucpr': tf.metrics.auc(curve='PR', **logistics_kwargs)
          }

          # EstimatorSpec object for evaluation
          return tf_estimator.EstimatorSpec(
              mode=mode,
              predictions=predictions,
              loss=loss,
              eval_metric_ops=eval_metric_ops)

      # TRAIN Mode
      if mode == tf_estimator.ModeKeys.TRAIN:
        train_op_primary = contrib_layers.optimize_loss(
            loss=loss,
            learning_rate=self._learning_rate,
            global_step=contrib_framework.get_global_step(),
            optimizer=self._optimizer)

        return tf_estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op_primary)

      # Any other mode (e.g. PREDICT) is unsupported; say so explicitly rather
      # than failing on an unbound `estimator_spec`.
      raise ValueError(
          'model_fn only supports TRAIN and EVAL modes, got: {}'.format(mode))
示例#25
0
def model_function(features, labels, mode, params, embeddings):
    """Estimator model function that scores candidate long-answer contexts.

    Args:
      features: Dictionary of feature tensors with keys:
          - question_tok: <string> [batch_size, max_question_len]
          - context_tok: <string> [batch_size, max_num_context,
            max_context_len]
          - question_tok_len: <int32> [batch_size]
          - num_context: <int32> [batch_size]
          - context_tok_len: <int32> [batch_size]
          - question_tok_wid: <int32> [batch_size, max_question_len]
          - context_tok_wid: <int32> [batch_size, max_num_context,
            max_context_len]
          - long_answer_indices: <int32> [batch_size]
      labels: <int32> [batch_size] for answer index (-1 = NULL).
      mode: One of the keys from tf.estimator.ModeKeys.
      params: Dictionary of hyperparameters (unused).
      embeddings: An embedding_utils.PretrainedWordEmbeddings object.

    Returns:
      A tf.estimator.EstimatorSpec. Outside PREDICT mode it carries the loss,
      the train op and precision / recall / f1 metrics; its predictions are
      always the best non-null context index and score.
    """
    del params  # Unused.

    exporting = mode == tf_estimator.ModeKeys.PREDICT

    if exporting:
        # The exported predictor receives unbatched inputs, so add a dummy
        # batch dimension.
        features = {
            name: tf.expand_dims(tensor, 0)
            for name, tensor in features.items()
        }

    embedding_weights, embedding_scaffold = embeddings.get_params(
        trainable=False)

    # Features.
    question_tok_len = features["question_tok_len"]
    question_tok_wid = features["question_tok_wid"]
    context_tok_wid = features["context_tok_wid"]
    num_context = features["num_context"]
    context_tok_len = features["context_tok_len"]

    # Truncate the contexts to a certain maximum count and length.
    truncated = nq_long_utils.truncate_contexts(
        context_token_ids=context_tok_wid,
        num_contexts=num_context,
        context_len=context_tok_len,
        max_contexts=FLAGS.max_contexts,
        max_context_len=FLAGS.max_context_len)
    context_tok_wid, num_context, context_tok_len = truncated

    raw_context_scores = nq_long_decatt_model.build_model(
        question_tok_wid=question_tok_wid,
        question_lens=question_tok_len,
        context_tok_wid=context_tok_wid,
        context_lens=context_tok_len,
        embedding_weights=embedding_weights,
        mode=mode)

    # Mask out contexts that are padding: the log of the 0/1 sequence mask is
    # added to the scores.
    is_real_context = tf.sequence_mask(
        num_context,
        tensor_utils.shape(raw_context_scores, 1),
        dtype=tf.float32)
    non_null_context_scores = raw_context_scores + tf.log(is_real_context)

    # <float> [batch_size, 1]
    batch_size = tf.shape(question_tok_wid)[0]
    null_score = tf.zeros([batch_size, 1])

    # Offset everything by 1 to account for null context.
    # [batch_size, 1 + max_contexts]
    context_scores = tf.concat([null_score, non_null_context_scores], 1)

    loss = None
    train_op = None
    eval_metric_ops = {}
    if not exporting:
        labels = nq_long_utils.truncate_labels(labels, FLAGS.max_contexts)

        # In the data, NULL is given index -1 but this is not compatible with
        # softmax so shift by 1.
        labels = labels + 1

        # Reweight null examples.
        weights = nq_long_utils.compute_null_weights(labels, FLAGS.null_weight)

        # Only the first label of each example enters the loss.
        # []
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels[:, 0],
                                                      logits=context_scores,
                                                      weights=weights)

        optimizer = tf.train.AdagradOptimizer(
            learning_rate=FLAGS.learning_rate)
        train_op = optimizer.minimize(loss=loss,
                                      global_step=tf.train.get_global_step())

        # <int32> [batch_size]
        eval_predictions = tf.to_int32(tf.argmax(context_scores, 1))

        match_stats = nq_long_utils.compute_match_stats(
            eval_predictions, labels)
        non_null_match, non_null_gold, non_null_predictions = match_stats

        precision_metric = tf.metrics.mean(non_null_match,
                                           weights=non_null_predictions)
        recall_metric = tf.metrics.mean(non_null_match,
                                        weights=non_null_gold)
        f1_metric = nq_long_utils.f1_metric(precision=precision_metric[0],
                                            precision_op=precision_metric[1],
                                            recall=recall_metric[0],
                                            recall_op=recall_metric[1])
        f1, f1_op = f1_metric

        # Bogus metric until we figure out how to connect Ming Wei's eval code.
        eval_metric_ops = {
            "precision": (precision_metric[0], precision_metric[1]),
            "recall": (recall_metric[0], recall_metric[1]),
            "f1": (f1, f1_op)
        }

    # In the export, we never predict NULL since the eval metric will compute
    # the best possible F1.
    predictions = {
        "idx": tf.to_int32(tf.argmax(non_null_context_scores, 1)),
        "score": tf.reduce_max(non_null_context_scores, 1),
    }

    if exporting:
        # Drop the dummy batch dimension again for the exported predictor.
        predictions = {
            name: tf.squeeze(tensor, 0)
            for name, tensor in predictions.items()
        }

    return tf_estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      predictions=predictions,
                                      train_op=train_op,
                                      eval_metric_ops=eval_metric_ops,
                                      scaffold=embedding_scaffold)
示例#26
0
    def _model_fn(features, labels, mode, params):
        """Constructs the model function.

    Args:
      features: Dictionary of input features.
      labels: Tensor of labels if mode is `TRAIN` or `EVAL`, otherwise `None`.
      mode: ModeKey object (`TRAIN` or `EVAL`).
      params: Parameter dictionary passed from the Estimator object.

    Returns:
      An EstimatorSpec object that encapsulates the model and its serving
        configurations.
    """
        del params  # Unused.

        def process_images(images):
            """Closure for processing images with fixed metadata."""
            return process.process(images, features['red_gain'],
                                   features['blue_gain'], features['cam2rgb'])

        denoised_img = inference_fn(features['noisy_img'],
                                    features['variance'])

        noisy_img = process_images(features['noisy_img'])
        denoised_img = process_images(denoised_img)
        truth_img = process_images(labels)

        if mode in [tf_estimator.ModeKeys.TRAIN, tf_estimator.ModeKeys.EVAL]:
            loss = tf.losses.absolute_difference(truth_img, denoised_img)
        else:
            loss = None

        if mode == tf_estimator.ModeKeys.TRAIN:
            optimizer = tf.train.AdamOptimizer(
                learning_rate=hparams.learning_rate)
            train_op = contrib_layers.optimize_loss(
                loss=loss,
                global_step=tf.train.get_global_step(),
                learning_rate=None,
                optimizer=optimizer,
                name='')  # Prevents scope prefix.
        else:
            train_op = None

        if mode == tf_estimator.ModeKeys.EVAL:
            eval_metric_ops = {'PSNR': psnr(truth_img, denoised_img)}

            def summary(images, name):
                """As a hack, saves image summaries by adding to `eval_metric_ops`."""
                images = tf.saturate_cast(images * 255 + 0.5, tf.uint8)
                eval_metric_ops[name] = (tf.summary.image(name,
                                                          images,
                                                          max_outputs=2),
                                         tf.no_op())

            summary(noisy_img, 'Noisy')
            summary(denoised_img, 'Denoised')
            summary(truth_img, 'Truth')

            diffs = (denoised_img - truth_img + 1.0) / 2.0
            summary(diffs, 'Diffs')

        else:
            eval_metric_ops = None

        return tf_estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops)
def model_fn_w_pruning(features, labels, mode, params):
  """Estimator model function for ResNet-50 with pruning.

  Args:
    features: A float32 batch of images, or (for some tasks) a dict that
      holds the images under "image_raw" together with "label" and/or
      "human_label" entries.
    labels: A int32 batch of labels; replaced by features["label"] for the
      tasks that carry labels inside `features`.
    mode: Specifies whether training, evaluation or prediction.
    params: Dictionary of parameters; reads "task", "mean_rgb", "stddev_rgb",
      "num_label_classes", "pruning_method", "label_smoothing" and
      "weight_decay".

  Returns:
    A EstimatorSpec for the model.
  """
  dict_input_tasks = ("pie_dataset_gen", "imagenet_training",
                      "imagenet_predictions")
  human_label_tasks = ("pie_dataset_gen", "robustness_imagenet_c",
                       "robustness_imagenet_a", "ckpt_prediction")

  task = params["task"]
  num_classes_key = "num_label_classes"

  if task in dict_input_tasks:
    images = features["image_raw"]
    labels = features["label"]
  else:
    images = features

  # NOTE(review): `human_labels` is only bound for these tasks, yet the EVAL
  # branch below reads it unconditionally -- confirm EVAL is only run for them.
  if task in human_label_tasks:
    human_labels = features["human_label"]

  # Normalize the image to zero mean and unit variance.
  rgb_shape = [1, 1, 3]
  mean_rgb = params["mean_rgb"]
  stddev_rgb = params["stddev_rgb"]
  images -= tf.constant(mean_rgb, shape=rgb_shape, dtype=images.dtype)
  images /= tf.constant(stddev_rgb, shape=rgb_shape, dtype=images.dtype)

  network = resnet_model.resnet_50(
      num_classes=params[num_classes_key],
      pruning_method=params["pruning_method"],
      data_format="channels_last")

  training = mode == tf_estimator.ModeKeys.TRAIN
  logits = network(inputs=images, is_training=training)

  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=tf.one_hot(labels, params[num_classes_key]),
      label_smoothing=params["label_smoothing"])

  # Add weight decay to the loss for non-batch-normalization variables.
  l2_terms = [
      tf.nn.l2_loss(var)
      for var in tf.trainable_variables()
      if "batch_normalization" not in var.name
  ]
  loss = cross_entropy + params["weight_decay"] * tf.add_n(l2_terms)

  if mode == tf_estimator.ModeKeys.PREDICT:
    # we run predictions on gpu since ordering is very important and
    # thus we need to run with batch size 1 (not enabled on tpu)
    train_op = None
    eval_metrics = None
    class_probabilities = tf.nn.softmax(logits, name="softmax")
    predicted_probability = tf.cast(
        tf.reduce_max(class_probabilities, axis=1), tf.float32)

    _, top_5_indices = tf.nn.top_k(tf.to_float(logits), k=5)

    predictions = {
        "predictions": tf.argmax(logits, axis=1),
        "true_class": labels,
        "predicted_probability": predicted_probability,
        "top_5_indices": top_5_indices
    }
  elif mode == tf_estimator.ModeKeys.TRAIN:
    train_op = train_function(params, loss)
    eval_metrics = None
    predictions = None
  elif mode == tf_estimator.ModeKeys.EVAL:
    train_op = None
    predictions = None
    params_eval = {
        "num_label_classes": params[num_classes_key],
        "log_class_level_summaries": False
    }
    eval_metrics = class_level_metrics.create_eval_metrics(
        labels, logits, human_labels, params_eval)

  return tf_estimator.EstimatorSpec(
      predictions=predictions,
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=eval_metrics)
示例#28
0
def model_function(features, labels, mode, params, embeddings):
    """Build the span-prediction graph and package it as an EstimatorSpec.

    Args:
        features: Dictionary of feature tensors with keys:
            - question: <string> [batch_size, max_question_len]
            - question_len: <int32> [batch_size]
            - question_cid: <int32> [batch_size, max_question_len, max_chars]
            - question_wid: <int32> [batch_size, max_question_len]
            - context: <string> [batch_size, max_context_len]
            - context_len: <int32> [batch_size]
            - context_cid: <int32> [batch_size, max_context_len, max_chars]
            - context_wid: <int32> [batch_size, max_context_len]
            - answer_start: <int32> [batch_size]
            - answer_end: <int32> [batch_size]
        labels: Pair of tensors containing the answer start and answer end.
        mode: One of the keys from tf.estimator.ModeKeys.
        params: Unused parameter dictionary.
        embeddings: An embedding_utils.PretrainedWordEmbeddings object.

    Returns:
        estimator_spec: A tf.estimator.EstimatorSpec object.
    """
    del params

    is_predict = mode == tf_estimator.ModeKeys.PREDICT
    is_train = mode == tf_estimator.ModeKeys.TRAIN
    is_eval = mode == tf_estimator.ModeKeys.EVAL

    if is_predict:
        # When exporting the predictor, give every feature a dummy leading
        # batch dimension.
        features = {
            name: tf.expand_dims(tensor, 0)
            for name, tensor in features.items()
        }

    embedding_weights, embedding_scaffold = embeddings.get_params(
        trainable=False)

    def _embed(prefix):
        """Concatenate word and character-CNN embeddings for `prefix`."""
        word_ids = features[prefix + "_wid"]
        char_ids = features[prefix + "_cid"]
        word_vectors = tf.nn.embedding_lookup(embedding_weights, word_ids)
        char_vectors = common_layers.character_cnn(
            char_ids=char_ids,
            emb_size=FLAGS.char_emb_size,
            kernel_width=FLAGS.char_kernel_width,
            num_filters=FLAGS.num_char_filters)
        combined = tf.concat([word_vectors, char_vectors], -1)
        if not is_train:
            return combined
        # Dropout is only applied while training.
        return tf.nn.dropout(combined, 1.0 - FLAGS.dropout_ratio)

    # The question and the context share the same embedding variables: the
    # second scope re-enters "embed" with reuse enabled.
    with tf.variable_scope("embed"):
        question_emb = _embed("question")
    with tf.variable_scope("embed", reuse=True):
        context_emb = _embed("context")

    # cuDNN is explicitly switched off for prediction; otherwise the choice
    # is left to score_endpoints.
    cudnn_setting = False if is_predict else None
    start_logits, end_logits = document_reader.score_endpoints(
        question_emb=question_emb,
        question_len=features["question_len"],
        context_emb=context_emb,
        context_len=features["context_len"],
        hidden_size=FLAGS.hidden_size,
        num_layers=FLAGS.num_layers,
        dropout_ratio=FLAGS.dropout_ratio,
        mode=mode,
        use_cudnn=cudnn_setting)

    loss = None
    if not is_predict:
        start_labels, end_labels = labels
        context_len = features["context_len"]

        # Long contexts are truncated, so a label may point past the end of
        # the kept context. Such labels are masked out of the loss.
        start_is_valid = tf.less(start_labels, context_len)
        end_is_valid = tf.less(end_labels, context_len)
        tf.summary.histogram("valid_start_labels",
                             tf.to_float(start_is_valid))
        tf.summary.histogram("valid_end_labels", tf.to_float(end_is_valid))

        placeholder_labels = tf.zeros_like(start_labels)

        def _masked_loss(span_labels, span_logits, is_valid):
            """Cross entropy where invalid labels get weight zero."""
            return tf.losses.sparse_softmax_cross_entropy(
                labels=tf.where(is_valid, span_labels, placeholder_labels),
                logits=span_logits,
                weights=tf.to_float(is_valid),
                reduction=tf.losses.Reduction.MEAN)

        start_loss = _masked_loss(start_labels, start_logits, start_is_valid)
        end_loss = _masked_loss(end_labels, end_logits, end_is_valid)
        loss = start_loss + end_loss

    # The train_op is only built for training, since the ADAM variables can
    # cause problems with loading checkpoints on CPUs.
    train_op = None
    if is_train:
        optimizer = tf.train.AdamOptimizer()
        grads, train_vars = zip(*optimizer.compute_gradients(loss))
        clipped_grads, _ = tf.clip_by_global_norm(grads, 5.0)
        train_op = optimizer.apply_gradients(
            grads_and_vars=list(zip(clipped_grads, train_vars)),
            global_step=tf.train.get_global_step())

    batch_size, max_context_len = tensor_utils.shape(features["context_wid"])
    tf.summary.histogram("batch_size", batch_size)
    tf.summary.histogram("non_padding",
                         features["context_len"] / max_context_len)

    start_predictions, end_predictions, predicted_score = (
        span_utils.max_scoring_span(start_logits, end_logits))

    # The exported end index is the predicted end position plus one.
    predictions = {
        "start_idx": start_predictions,
        "end_idx": end_predictions + 1,
        "score": predicted_score,
    }

    if is_predict:
        # Drop the dummy batch dimension that was added for export.
        predictions = {
            name: tf.squeeze(tensor, 0)
            for name, tensor in predictions.items()
        }

    eval_metric_ops = None
    if is_eval:
        text_summary = get_text_summary(question=features["question"],
                                        context=features["context"],
                                        start_predictions=start_predictions,
                                        end_predictions=end_predictions)

        # TODO(kentonl): Replace this with @mingweichang's official eval script.
        start_matches = tf.equal(start_predictions, start_labels)
        end_matches = tf.equal(end_predictions, end_labels)
        exact_match = tf.logical_and(start_matches, end_matches)

        eval_metric_ops = {
            "exact_match": tf.metrics.mean(exact_match),
            "text_summary": (text_summary, tf.no_op()),
        }

    return tf_estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        predictions=predictions,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        scaffold=embedding_scaffold)
示例#29
0
def resnet_model_fn_w_pruning(features, labels, mode, params):
    """The model_fn for a ResNet with pruning.

    Args:
        features: A float32 batch of images, or a dict holding that batch
            under the 'feature' key.
        labels: A int32 batch of labels.
        mode: A tf.estimator ModeKeys value (TRAIN, EVAL or PREDICT).
        params: Dictionary of parameters passed to the model. The keys
            'pruning_method', 'use_tpu' and 'log_alpha_threshold' are always
            read; 'output_dir' is read outside PREDICT mode.

    Returns:
        An EstimatorSpec in PREDICT mode, otherwise a TPUEstimatorSpec.

    Raises:
        ValueError: If FLAGS.precision is neither 'bfloat16' nor 'float32',
            or if the 'scratch' pruning method is requested without
            FLAGS.load_mask_dir.
    """

    width = 1. if FLAGS.width <= 0 else FLAGS.width
    if isinstance(features, dict):
        features = features['feature']

    if FLAGS.data_format == 'channels_first':
        assert not FLAGS.transpose_input  # channels_first only for GPU
        features = tf.transpose(features, [0, 3, 1, 2])

    if FLAGS.transpose_input and mode != tf_estimator.ModeKeys.PREDICT:
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    # Normalize the image to zero mean and unit variance. The per-channel
    # statistics have to broadcast along the channel axis: that axis is last
    # for channels_last input, but sits before the two spatial axes once the
    # input has been transposed to channels_first above.
    if FLAGS.data_format == 'channels_first':
        stats_shape = [3, 1, 1]
    else:
        stats_shape = [1, 1, 3]
    features -= tf.constant(MEAN_RGB, shape=stats_shape, dtype=features.dtype)
    features /= tf.constant(STDDEV_RGB,
                            shape=stats_shape,
                            dtype=features.dtype)

    pruning_method = params['pruning_method']
    use_tpu = params['use_tpu']
    log_alpha_threshold = params['log_alpha_threshold']

    def build_network():
        """Construct the network in the graph."""
        model_pruning_method = pruning_method
        if pruning_method == 'scratch':
            model_pruning_method = 'threshold'

        network = resnet_model.resnet_v1_(
            resnet_depth=FLAGS.resnet_depth,
            num_classes=FLAGS.num_label_classes,
            # we need to construct the model with the pruning masks, but they won't
            # be updated if we're doing scratch training
            pruning_method=model_pruning_method,
            init_method=FLAGS.init_method,
            width=width,
            prune_first_layer=FLAGS.prune_first_layer,
            prune_last_layer=FLAGS.prune_last_layer,
            data_format=FLAGS.data_format,
            end_sparsity=FLAGS.end_sparsity,
            clip_log_alpha=FLAGS.clip_log_alpha,
            log_alpha_threshold=log_alpha_threshold,
            weight_decay=FLAGS.weight_decay)
        return network(inputs=features,
                       is_training=(mode == tf_estimator.ModeKeys.TRAIN))

    if FLAGS.precision == 'bfloat16':
        with contrib_tpu.bfloat16_scope():
            logits = build_network()
        logits = tf.cast(logits, tf.float32)
    elif FLAGS.precision == 'float32':
        logits = build_network()
    else:
        # Fail with a clear message instead of an unbound `logits` later on.
        raise ValueError('Unsupported precision: %s' % FLAGS.precision)

    if mode == tf_estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf_estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf_estimator.export.PredictOutput(predictions)
            })

    output_dir = params['output_dir']

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)

    # make sure we reuse the same label smoothing parameter if we're doing
    # scratch / lottery ticket experiments.
    # NOTE(review): this block and the eval / scaffold code below read
    # FLAGS.pruning_method, while the network and the regularizers use
    # params['pruning_method'] - confirm the two are always equal.
    label_smoothing = FLAGS.label_smoothing
    if FLAGS.pruning_method == 'scratch':
        # Validate before parsing the path; otherwise a missing mask directory
        # crashes on the split below and this error is never reached.
        if not FLAGS.load_mask_dir:
            raise ValueError(
                'Must supply a mask directory to use scratch method')
        # NOTE(review): assumes the 16th '/'-separated component of
        # load_mask_dir encodes the label smoothing value - confirm against
        # the directory layout that produces the masks.
        label_smoothing = float(FLAGS.load_mask_dir.split('/')[15])
    loss = tf.losses.softmax_cross_entropy(logits=logits,
                                           onehot_labels=one_hot_labels,
                                           label_smoothing=label_smoothing)
    # Add regularization loss term
    loss += tf.losses.get_regularization_loss()

    if pruning_method == 'variational_dropout':
        reg_loss = utils.variational_dropout_dkl_loss(
            reg_scalar=FLAGS.reg_scalar,
            start_reg_ramp_up=FLAGS.sparsity_begin_step,
            end_reg_ramp_up=FLAGS.sparsity_end_step,
            warm_up=FLAGS.is_warm_up,
            use_tpu=use_tpu)
        loss += reg_loss
        tf.losses.add_loss(reg_loss, loss_collection=tf.GraphKeys.LOSSES)
    elif pruning_method == 'l0_regularization':
        reg_loss = utils.l0_regularization_loss(
            reg_scalar=FLAGS.reg_scalar,
            start_reg_ramp_up=FLAGS.sparsity_begin_step,
            end_reg_ramp_up=FLAGS.sparsity_end_step,
            warm_up=FLAGS.is_warm_up,
            use_tpu=use_tpu)
        loss += reg_loss
        tf.losses.add_loss(reg_loss, loss_collection=tf.GraphKeys.LOSSES)

    host_call = None
    train_op = None
    if mode == tf_estimator.ModeKeys.TRAIN:
        host_call, train_op = train_function(pruning_method, loss, output_dir,
                                             use_tpu)

    eval_metrics = None
    if mode == tf_estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Calculate top-1 and top-5 eval accuracy."""
            logging.info('In metric function')
            metrics = {}
            predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            metrics['top_5_eval_accuracy'] = tf.metrics.mean(in_top_5)
            metrics['eval_accuracy'] = tf.metrics.accuracy(
                labels=labels, predictions=predictions)

            return metrics

        def vd_metric_fn(labels, logits, global_sparsity):
            """Accuracy metrics plus the mean global sparsity."""
            metrics = metric_fn(labels, logits)
            metrics['global_sparsity'] = tf.metrics.mean(global_sparsity)
            return metrics

        tensors = [labels, logits]
        metric_function = metric_fn

        if FLAGS.pruning_method == 'variational_dropout':
            batch_size = labels.shape[0]
            ones = tf.ones([batch_size, 1])
            mask_metrics = utils.add_vd_pruning_summaries(
                threshold=FLAGS.log_alpha_threshold)
            # The sparsity value is repeated once per example so that the
            # tensor handed to the metric function has a leading batch
            # dimension like `labels` and `logits`.
            tensors.append(mask_metrics['global_sparsity'] * ones)
            metric_function = vd_metric_fn

        eval_metrics = (metric_function, tensors)

    def _training_already_started():
        """Return whether FLAGS.output_dir already holds a checkpoint."""
        model_dir = FLAGS.output_dir
        return bool(model_dir) and (tf.train.latest_checkpoint(model_dir)
                                    is not None)

    # define a custom scaffold function to enable initializing the mask from an
    # already trained checkpoint.
    def initialize_mask_from_ckpt(ckpt_path):
        """Load mask from an existing checkpoint."""
        if _training_already_started():
            tf.logging.info(
                'Training already started on this model, not loading masks '
                'from previously trained model')
            return

        reader = tf.train.NewCheckpointReader(ckpt_path)
        # A set keeps the per-variable membership test below constant time.
        mask_names = {
            name for name in reader.get_variable_to_shape_map()
            if name.endswith('mask')
        }

        variable_map = {}
        for var in tf.global_variables():
            var_name = var.name.split(':')[0]
            if var_name in mask_names:
                tf.logging.info('Loading mask variable from checkpoint: %s',
                                var_name)
                variable_map[var_name] = var
            elif 'mask' in var_name:
                tf.logging.info(
                    'Cannot find mask variable in checkpoint, skipping: %s',
                    var_name)
        tf.train.init_from_checkpoint(ckpt_path, variable_map)

    def initialize_parameters_from_ckpt(ckpt_path):
        """Load parameters from an existing checkpoint."""
        if _training_already_started():
            tf.logging.info(
                'Training already started on this model, not loading '
                'parameters from previously trained model')
            return

        reader = tf.train.NewCheckpointReader(ckpt_path)
        param_names = {
            name for name in reader.get_variable_to_shape_map()
            if not name.endswith('mask')
        }

        variable_map = {}
        for var in tf.global_variables():
            var_name = var.name.split(':')[0]
            if var_name in param_names:
                tf.logging.info(
                    'Loading parameter variable from checkpoint: %s', var_name)
                variable_map[var_name] = var
            elif 'mask' not in var_name:
                tf.logging.info(
                    'Cannot find parameter variable in checkpoint, skipping: %s',
                    var_name)
        tf.train.init_from_checkpoint(ckpt_path, variable_map)

    if FLAGS.pruning_method == 'scratch':
        # FLAGS.load_mask_dir was validated before the loss was built.
        def scaffold_fn():
            """Initialize masks (and optionally weights) from checkpoints."""
            initialize_mask_from_ckpt(FLAGS.load_mask_dir)
            if FLAGS.initial_value_checkpoint:
                initialize_parameters_from_ckpt(
                    FLAGS.initial_value_checkpoint)
            return tf.train.Scaffold()
    else:
        scaffold_fn = None

    return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                        loss=loss,
                                        train_op=train_op,
                                        host_call=host_call,
                                        eval_metrics=eval_metrics,
                                        scaffold_fn=scaffold_fn)
        def model_fn(features, labels, mode):
            """AdversarialReweightingModel model_fn.

            Args:
                features: `Tensor` or `dict` of `Tensor`.
                labels: A `dict` of `Tensor` Objects. Expects to have a
                    key/value pair for the key self._label_column_name.
                mode: Defines whether this is training, evaluation or
                    prediction. See `ModeKeys`. Currently PREDICT mode is not
                    implemented.

            Returns:
                An instance of `tf.estimator.EstimatorSpec`, which
                encapsulates the `mode`, `predictions`, `loss` and, when
                training, the `train_op`. Here `predictions` is a `dict` of
                `Tensor` objects representing the prediction of the binary
                classification model, `loss` is the loss of the step and
                `train_op` is the op for training.

            Raises:
                NotImplementedError: If `mode` is neither TRAIN nor EVAL.
            """

            # Only TRAIN and EVAL build an EstimatorSpec below; reject any
            # other mode up front instead of failing later on an unbound
            # `estimator_spec`.
            if mode not in (tf_estimator.ModeKeys.TRAIN,
                            tf_estimator.ModeKeys.EVAL):
                raise NotImplementedError(
                    'model_fn only supports TRAIN and EVAL modes, got: '
                    '{}'.format(mode))

            # Instantiates a tensor with true class labels
            class_labels = labels[self._label_column_name]

            # Instantiates a tensor with weight for positive class examples only
            pos_weights = tf.cast(tf.equal(class_labels, 1), dtype=tf.float32)

            # Initialize a global step variable used for alternate training
            current_step = self._get_or_create_global_step_var()

            if mode == tf_estimator.ModeKeys.EVAL:
                tf.logging.info('model_fn: EVAL, {}'.format(mode))
            elif mode == tf_estimator.ModeKeys.TRAIN:
                tf.logging.info('model_fn: TRAIN, {}'.format(mode))

            # Creates a DNN architecture for primary binary classification task
            with tf.name_scope('primary_NN'):
                with tf.variable_scope('primary'):
                    input_layer = tf.feature_column.input_layer(
                        features, self._feature_columns)
                    h1 = tf.layers.Dense(
                        self._primary_hidden_units[0],
                        activation=self._activation)(input_layer)
                    h2 = tf.layers.Dense(self._primary_hidden_units[1],
                                         activation=self._activation)(h1)
                    logits = tf.layers.Dense(1)(h2)
                    sigmoid_output = tf.nn.sigmoid(logits, name='sigmoid')
                    # Threshold the sigmoid output at 0.5 to get the class.
                    class_predictions = tf.cast(
                        tf.greater(sigmoid_output, 0.5), tf.float32)
                    tf.summary.histogram('class_predictions',
                                         class_predictions)

            # Creates a network architecture for the adversarial regression task
            with tf.name_scope('adversary_NN'):
                with tf.variable_scope('adversary'):
                    # Gets adversary features and features columns
                    adversarial_features, adversary_feature_columns = self._get_adversary_features_and_feature_columns(features, labels)  # pylint: disable=line-too-long
                    adv_input_layer = tf.feature_column.input_layer(
                        adversarial_features, adversary_feature_columns)
                    adv_h1 = tf.layers.Dense(
                        self._adversary_hidden_units[0])(adv_input_layer)
                    adv_output_layer = tf.layers.Dense(1,
                                                       use_bias=True)(adv_h1)
                    # Until `pretrain_steps` is exceeded every example gets
                    # weight one; afterwards the adversary output sets the
                    # weights.
                    example_weights = tf.cond(
                        tf.greater(current_step, self._pretrain_steps),
                        true_fn=lambda: self._compute_example_weights(
                            adv_output_layer),
                        false_fn=lambda: tf.ones_like(class_labels))

            # Adds summary variables to tensorboard
            with tf.name_scope('example_weights'):
                tf.summary.histogram('example_weights', example_weights)
                tf.summary.histogram('label', class_labels)

            # Initializes Loss Functions
            primary_loss = self._primary_loss(class_labels, logits,
                                              example_weights)
            adversary_loss = self._adversary_loss(class_labels, logits,
                                                  pos_weights, example_weights,
                                                  self._adversary_loss_type)

            # Sets up dictionaries used for computing performance metrics
            predictions = {
                (self._label_column_name, 'class_ids'):
                tf.reshape(class_predictions, [-1]),
                (self._label_column_name, 'logistic'):
                tf.reshape(sigmoid_output, [-1]),
                ('example_weights'):
                tf.reshape(example_weights, [-1])
            }

            class_id_kwargs = {
                'labels': class_labels,
                'predictions': class_predictions
            }
            logistics_kwargs = {
                'labels': class_labels,
                'predictions': sigmoid_output
            }

            # EVAL Mode
            if mode == tf_estimator.ModeKeys.EVAL:
                with tf.name_scope('eval_metrics'):
                    eval_metric_ops = {
                        'accuracy':
                        tf.metrics.accuracy(**class_id_kwargs),
                        'precision':
                        tf.metrics.precision(**class_id_kwargs),
                        'recall':
                        tf.metrics.recall(**class_id_kwargs),
                        'fp':
                        tf.metrics.false_positives(**class_id_kwargs),
                        'fn':
                        tf.metrics.false_negatives(**class_id_kwargs),
                        'tp':
                        tf.metrics.true_positives(**class_id_kwargs),
                        'tn':
                        tf.metrics.true_negatives(**class_id_kwargs),
                        'fpr':
                        contrib_metrics.streaming_false_positive_rate(
                            **class_id_kwargs),  # pylint: disable=line-too-long
                        'fnr':
                        contrib_metrics.streaming_false_negative_rate(
                            **class_id_kwargs),  # pylint: disable=line-too-long
                        'auc':
                        tf.metrics.auc(curve='ROC', **logistics_kwargs),
                        'aucpr':
                        tf.metrics.auc(curve='PR', **logistics_kwargs)
                    }

                    # EstimatorSpec object for evaluation
                    return tf_estimator.EstimatorSpec(
                        mode=mode,
                        predictions=predictions,
                        loss=primary_loss,
                        eval_metric_ops=eval_metric_ops)

            # TRAIN Mode (the only mode left after the guard and EVAL return).
            # Filters trainable variables for each task
            all_trainable_vars = tf.trainable_variables()
            primary_trainable_vars = [
                v for v in all_trainable_vars if 'primary' in v.op.name
            ]
            adversary_trainable_vars = [
                v for v in all_trainable_vars if 'adversary' in v.op.name
            ]

            # TRAIN_OP for adversary DNN
            train_op_adversary = contrib_layers.optimize_loss(
                loss=adversary_loss,
                variables=adversary_trainable_vars,
                global_step=contrib_framework.get_global_step(),
                learning_rate=self._adversary_learning_rate,
                optimizer=self._optimizer)

            # TRAIN_OP for primary DNN
            train_op_primary = contrib_layers.optimize_loss(
                loss=primary_loss,
                variables=primary_trainable_vars,
                global_step=contrib_framework.get_global_step(),
                learning_rate=self._primary_learning_rate,
                optimizer=self._optimizer)

            # Upto ``pretrain_steps'' trains primary only.
            # Beyond ``pretrain_steps'' alternates between primary and adversary.
            # NOTE(review): both train ops are created outside the cond
            # branches. In TF1 graph mode, ops created outside true_fn /
            # false_fn run whichever branch is taken, so the adversary may
            # also be updated during pretraining - confirm this is intended.
            return tf_estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                loss=primary_loss + adversary_loss,
                train_op=tf.cond(
                    tf.greater(current_step, self._pretrain_steps),
                    true_fn=lambda: tf.group(
                        [train_op_primary, train_op_adversary]),  # pylint: disable=line-too-long
                    false_fn=lambda: tf.group([train_op_primary])))