def model(use_custom_op, inputs, targets):
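    """Build a small fp16 binary classifier; `use_custom_op` names the optimizer."""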
    weights = tf.get_variable("weights",
                              shape=[3, 1],
                              initializer=tf.zeros_initializer(),
                              dtype=tf.float16)
    # Forward function:
    preds = tf.matmul(inputs, weights)

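    # Loss function: a tanh-based sigmoid (0.5 * (tanh(x) + 1) == sigmoid(2x))
    # turns the logits into probabilities, and the Bernoulli negative
    # log-likelihood of the targets is summed into a scalar loss.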
    sigmoid = 0.5 * (tf.math.tanh(preds) + 1)
    probs = sigmoid * targets + (1 - sigmoid) * (1 - targets)
    training_loss = tf.math.reduce_sum(-tf.math.log(probs))

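    # Gradient of the loss w.r.t. the inputs, returned for inspection/testing.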
    grad_of_loss_wrt_input = tf.gradients(training_loss, [inputs])[0]

    # Optimiser:
    if use_custom_op == 'sgd':
        opt = tf.train.GradientDescentOptimizer(learning_rate=0.05)
    elif use_custom_op == 'momentum':
        opt = tf.train.MomentumOptimizer(learning_rate=0.05,
                                         momentum=0.9,
                                         use_nesterov=False)
    elif use_custom_op == 'lamb':
        opt = _opt.LAMBOptimizer(0.05, high_precision=False)
    elif use_custom_op == 'adamw':
        opt = _opt.AdamWeightDecayOptimizer(0.05)
    else:
        raise ValueError("Unknown optimizer: %s" % use_custom_op)
    train_op = opt.minimize(training_loss)

    return training_loss, weights, grad_of_loss_wrt_input, train_op
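
# Example usage (a sketch: assumes `_opt` supplies the custom LAMB/AdamW
# optimizers referenced above and that fp16 inputs/targets are fed in):
#   inputs = tf.placeholder(tf.float16, shape=[None, 3])
#   targets = tf.placeholder(tf.float16, shape=[None, 1])
#   loss, weights, grad, train_op = model('momentum', inputs, targets)
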
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for Estimator."""
    if FLAGS.verbose_logging:
      tf.compat.v1.logging.info("*** Features ***")
      for name in sorted(features.keys()):
        tf.compat.v1.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

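    # For inference with TF-TRT, import the frozen TensorRT-optimized graph and
    # map the Estimator's input tensors onto its placeholders.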
    if not is_training and FLAGS.use_trt:
      trt_graph = get_frozen_tftrt_model(bert_config, input_ids.shape, use_one_hot_embeddings, init_checkpoint)
      (start_logits, end_logits) = tf.import_graph_def(
          trt_graph,
          input_map={'input_ids': input_ids, 'input_mask': input_mask, 'segment_ids': segment_ids},
          return_elements=['unstack:0', 'unstack:1'],
          name='')
      predictions = {
          "unique_ids": unique_ids,
          "start_logits": start_logits,
          "end_logits": end_logits,
      }
      output_spec = tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
      return output_spec

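    # Otherwise build the BERT question-answering graph (start/end span logits).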
    (start_logits, end_logits) = create_model(
        bert_config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

    tvars = tf.trainable_variables()

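    # With Horovod, only rank 0 restores the checkpoint; the other ranks are
    # expected to pick up the weights via a broadcast at startup.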
    initialized_variable_names = {}
    if init_checkpoint and (hvd is None or hvd.rank() == 0):
      (assignment_map, initialized_variable_names
      ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    if FLAGS.verbose_logging:
      tf.compat.v1.logging.info("**** Trainable Variables ****")
      for var in tvars:
        init_string = ""
        if var.name in initialized_variable_names:
          init_string = ", *INIT_FROM_CKPT*"
        tf.compat.v1.logging.info("  %d name = %s, shape = %s%s",
                                  0 if hvd is None else hvd.rank(),
                                  var.name, var.shape, init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      seq_length = modeling.get_shape_list(input_ids)[1]

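      # Cross-entropy over token positions: one-hot the gold start/end index
      # and average the negative log-probability across the batch.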
      def compute_loss(logits, positions):
        one_hot_positions = tf.one_hot(
            positions, depth=seq_length, dtype=tf.float32)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        loss = -tf.reduce_mean(
            tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
        return loss

      start_positions = features["start_positions"]
      end_positions = features["end_positions"]

      start_loss = compute_loss(start_logits, start_positions)
      end_loss = compute_loss(end_logits, end_positions)

      total_loss = (start_loss + end_loss) / 2.0

      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, hvd, False, amp, FLAGS.num_accumulation_steps)

      output_spec = tf.estimator.EstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op)
    elif mode == tf.estimator.ModeKeys.PREDICT:

      dummy_op = tf.no_op()
      # Enable the AMP graph rewrite for the inference graph by wrapping a
      # throwaway optimizer; the fixed loss scale of 1 keeps scaling numerically inert.
      if amp:
        loss_scaler = tf.train.experimental.FixedLossScale(1)
        dummy_op = tf.train.experimental.enable_mixed_precision_graph_rewrite(
            optimization.LAMBOptimizer(learning_rate=0.0), loss_scaler)

      predictions = {
          "unique_ids": unique_ids,
          "start_logits": start_logits,
          "end_logits": end_logits,
      }
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode, predictions=predictions)
    else:
      raise ValueError(
          "Only TRAIN and PREDICT modes are supported: %s" % (mode))

    return output_spec
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.compat.v1.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.compat.v1.logging.info("  name = %s, shape = %s" %
                                      (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        (total_loss, per_example_loss, logits,
         probabilities) = create_model(bert_config, is_training, input_ids,
                                       input_mask, segment_ids, label_ids,
                                       num_labels, use_one_hot_embeddings)

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        if init_checkpoint and (hvd is None or hvd.rank() == 0):
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.compat.v1.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.compat.v1.logging.info("  name = %s, shape = %s%s", var.name,
                                      var.shape, init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:

            train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps, hvd,
                                                     False, amp)

            output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                     loss=total_loss,
                                                     train_op=train_op)
        elif mode == tf.estimator.ModeKeys.EVAL:

            dummy_op = tf.no_op()
            # Enable the AMP graph rewrite for the eval graph by wrapping a throwaway
            # optimizer; the fixed loss scale of 1 keeps scaling numerically inert.
            if amp:
                loss_scaler = tf.train.experimental.FixedLossScale(1)
                dummy_op = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimization.LAMBOptimizer(learning_rate=0.0), loss_scaler)

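            # Accuracy and mean loss, weighted so padding examples don't count.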
            def metric_fn(per_example_loss, label_ids, logits,
                          is_real_example):
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(labels=label_ids,
                                               predictions=predictions,
                                               weights=is_real_example)
                loss = tf.metrics.mean(values=per_example_loss,
                                       weights=is_real_example)
                return {
                    "eval_accuracy": accuracy,
                    "eval_loss": loss,
                }

            eval_metric_ops = metric_fn(per_example_loss, label_ids, logits,
                                        is_real_example)
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=total_loss, eval_metric_ops=eval_metric_ops)
        else:
            dummy_op = tf.no_op()
            # Enable the AMP graph rewrite for the inference graph by wrapping a
            # throwaway optimizer (no loss scaling is needed without a training loss).
            if amp:
                dummy_op = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimization.LAMBOptimizer(learning_rate=0.0))

            output_spec = tf.estimator.EstimatorSpec(
                mode=mode, predictions={"probabilities": probabilities})
        return output_spec
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for Estimator."""

    def metric_fn(per_example_loss, label_ids, logits):
      predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
      if task_name == "cola":
        FN, FN_op = tf.metrics.false_negatives(labels=label_ids, predictions=predictions)
        FP, FP_op = tf.metrics.false_positives(labels=label_ids, predictions=predictions)
        TP, TP_op = tf.metrics.true_positives(labels=label_ids, predictions=predictions)
        TN, TN_op = tf.metrics.true_negatives(labels=label_ids, predictions=predictions)

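        # Matthews correlation coefficient from the streaming confusion-matrix
        # counts; note the denominator can be zero for degenerate predictions.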
        MCC = (TP * TN - FP * FN) / ((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)) ** 0.5
        MCC_op = tf.group(FN_op, TN_op, TP_op, FP_op, tf.identity(MCC, name="MCC"))
        return {"MCC": (MCC, MCC_op)}
        elif task_name == "mrpc":
            accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=predictions)
            loss = tf.metrics.mean(values=per_example_loss)
            f1 = tf_metrics.f1(labels=label_ids, predictions=predictions, num_classes=2, pos_indices=[1])
            return {
                "eval_accuracy": accuracy,
                "eval_f1": f1,
                "eval_loss": loss,
            }
        else:
            accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=predictions)
            loss = tf.metrics.mean(values=per_example_loss)
            return {
                "eval_accuracy": accuracy,
                "eval_loss": loss,
            }
    tf.compat.v1.logging.info("*** Features ***")
    tf.compat.v1.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.compat.v1.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

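    # For eval/prediction with TF-TRT, import the frozen TensorRT-optimized
    # graph and map the Estimator's input tensors onto its placeholders.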
    if not is_training and FLAGS.use_trt:
      trt_graph = get_frozen_tftrt_model(bert_config, input_ids.shape, num_labels, use_one_hot_embeddings, init_checkpoint)
      (total_loss, per_example_loss, logits, probabilities) = tf.import_graph_def(
          trt_graph,
          input_map={'input_ids': input_ids, 'input_mask': input_mask, 'segment_ids': segment_ids, 'label_ids': label_ids},
          return_elements=['loss/cls_loss:0', 'loss/cls_per_example_loss:0', 'loss/cls_logits:0', 'loss/cls_probabilities:0'],
          name='')
      if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {"probabilities": probabilities}
        output_spec = tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
      elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metric_ops = metric_fn(per_example_loss, label_ids, logits)
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=total_loss, eval_metric_ops=eval_metric_ops)
      return output_spec

    (total_loss, per_example_loss, logits, probabilities) = create_model(
        bert_config, is_training, input_ids, input_mask, segment_ids, label_ids,
        num_labels, use_one_hot_embeddings)

    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    if init_checkpoint and (hvd is None or hvd.rank() == 0):
      (assignment_map, initialized_variable_names
      ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    if FLAGS.verbose_logging:
      tf.compat.v1.logging.info("**** Trainable Variables ****")
      for var in tvars:
        init_string = ""
        if var.name in initialized_variable_names:
          init_string = ", *INIT_FROM_CKPT*"
        tf.compat.v1.logging.info("  name = %s, shape = %s%s",
                                  var.name, var.shape, init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:

      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps,
          hvd, False, FLAGS.amp, FLAGS.num_accumulation_steps, FLAGS.optimizer_type)
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op)
    elif mode == tf.estimator.ModeKeys.EVAL:
      dummy_op = tf.no_op()
      # Enable the AMP graph rewrite for the eval graph by wrapping a throwaway
      # optimizer; the fixed loss scale of 1 keeps scaling numerically inert.
      if FLAGS.amp:
        loss_scaler = tf.train.experimental.FixedLossScale(1)
        dummy_op = tf.train.experimental.enable_mixed_precision_graph_rewrite(
            optimization.LAMBOptimizer(learning_rate=0.0), loss_scaler)
      eval_metric_ops = metric_fn(per_example_loss, label_ids, logits)
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metric_ops=eval_metric_ops)
    else:
      dummy_op = tf.no_op()
      # Enable the AMP graph rewrite for the inference graph by wrapping a
      # throwaway optimizer (no loss scaling is needed without a training loss).
      if FLAGS.amp:
        dummy_op = tf.train.experimental.enable_mixed_precision_graph_rewrite(
            optimization.LAMBOptimizer(learning_rate=0.0))
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode, predictions=probabilities)
    return output_spec
    def model_fn(features, labels, mode, params):
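        """The `model_fn` for Estimator."""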
        tf.compat.v1.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.compat.v1.logging.info("  name = %s, shape = %s" % (name, features[name].shape))
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        # label_mask = features["label_mask"]
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        (total_loss, per_example_loss, logits, predicts) = create_model(
            bert_config, is_training, input_ids, input_mask, segment_ids, label_ids,
            num_labels, use_one_hot_embeddings)
        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        if init_checkpoint and (hvd is None or hvd.rank() == 0):
            (assignment_map,
             initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,
                                                                                       init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        tf.compat.v1.logging.info("**** Trainable Variables ****")

        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.compat.v1.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)
        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps, hvd, False, amp)
            output_spec = tf.estimator.EstimatorSpec(
              mode=mode,
              loss=total_loss,
              train_op=train_op)
        elif mode == tf.estimator.ModeKeys.EVAL:
            dummy_op = tf.no_op()
            # Enable the AMP graph rewrite for the eval graph by wrapping a throwaway
            # optimizer; the fixed loss scale of 1 keeps scaling numerically inert.
            if amp:
                loss_scaler = tf.train.experimental.FixedLossScale(1)
                dummy_op = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimization.LAMBOptimizer(learning_rate=0.0), loss_scaler)

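            # Macro-averaged precision/recall/F1 restricted to label classes 1
            # and 2 (`pos_indices`); remaining classes are excluded from the scores.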
            def metric_fn(per_example_loss, label_ids, logits):
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                precision = tf_metrics.precision(label_ids, predictions, num_labels, [1, 2], average="macro")
                recall = tf_metrics.recall(label_ids, predictions, num_labels, [1, 2], average="macro")
                f = tf_metrics.f1(label_ids, predictions, num_labels, [1, 2], average="macro")
                return {
                    "precision": precision,
                    "recall": recall,
                    "f1": f,
                }

            eval_metric_ops = metric_fn(per_example_loss, label_ids, logits)
            output_spec = tf.estimator.EstimatorSpec(
              mode=mode,
              loss=total_loss,
              eval_metric_ops=eval_metric_ops)
        else:

            dummy_op = tf.no_op()
            # Enable the AMP graph rewrite for the inference graph by wrapping a
            # throwaway optimizer (no loss scaling is needed without a training loss).
            if amp:
                dummy_op = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimization.LAMBOptimizer(learning_rate=0.0))

            output_spec = tf.estimator.EstimatorSpec(
                mode=mode, predictions=predicts)
        return output_spec