  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    num_gpus = n_gpus
    if is_training:
      optimizer = optimization.create_optimizer_mgpu(learning_rate, num_train_steps, num_warmup_steps)
    else:
      num_gpus = 1

    input_ids_list = tf.split(features["input_ids"], num_or_size_splits=num_gpus, axis=0)
    input_mask_list = tf.split(features["input_mask"], num_or_size_splits=num_gpus, axis=0)
    segment_ids_list = tf.split(features["segment_ids"], num_or_size_splits=num_gpus, axis=0)
    masked_lm_positions_list = tf.split(features["masked_lm_positions"], num_or_size_splits=num_gpus, axis=0)
    masked_lm_ids_list = tf.split(features["masked_lm_ids"], num_or_size_splits=num_gpus, axis=0)
    masked_lm_weights_list = tf.split(features["masked_lm_weights"], num_or_size_splits=num_gpus, axis=0)
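    # Data parallelism: every feature tensor is sharded along the batch
    # dimension, so the global batch size must be divisible by num_gpus.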

    tower_grads = []
    train_perplexity = 0
    for index in range(num_gpus):
      with tf.name_scope('replica_%d' % index):
        with tf.device('/gpu:%d' % index):
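          # Each tower builds the full BERT graph plus the masked-LM head over
          # its batch shard; the training branch below calls reuse_variables()
          # so towers after the first share tower 0's weights.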
          model = modeling.BertModel(
              config=bert_config,
              is_training=is_training,
              input_ids=input_ids_list[index],
              input_mask=input_mask_list[index],
              token_type_ids=segment_ids_list[index],
              use_one_hot_embeddings=use_one_hot_embeddings)

          (masked_lm_loss,
           masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
               bert_config, model.get_sequence_output(), model.get_embedding_table(),
               masked_lm_positions_list[index], masked_lm_ids_list[index], masked_lm_weights_list[index])

          total_loss = masked_lm_loss

          tvars = tf.trainable_variables()

          scaffold_fn = None
          initialized_variable_names = {}
          if init_checkpoint and index == 0:
            (assignment_map,
             initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
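            # Pre-create the Adam slot variables (m and v) for each trainable
            # variable so they exist before any variable reuse kicks in;
            # presumably create_optimizer_mgpu then picks these slots up
            # instead of creating them under a reusing scope.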
            for var in tvars:
              param_name = var.name[:-2]
              tf.get_variable(
                name=param_name + "/adam_m",
                shape=var.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer())
              tf.get_variable(
                name=param_name + "/adam_v",
                shape=var.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer())

            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

            tf.logging.info("**** Trainable Variables ****")
            for var in tvars:
              init_string = ""
              if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
              tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                              init_string)
          if is_training:
            # reuse variables
            tf.get_variable_scope().reuse_variables()
            loss = total_loss
            # get gradients
            grads = optimizer.compute_gradients(
              loss,
              aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE,
            )
            tower_grads.append(grads)
            # keep track of loss across all GPUs
            train_perplexity += loss

    if mode == tf.estimator.ModeKeys.TRAIN:
      global_step = tf.train.get_or_create_global_step()
      new_global_step = global_step + 1

      average_grads = average_gradients(tower_grads, None, None)
      average_grads, norm_summary_ops = clip_grads(average_grads, 1.0, True, global_step)

      # apply_gradients is not given global_step, so the step counter is
      # incremented manually and grouped with the weight-update op.
      train_op = optimizer.apply_gradients(average_grads)
      train_op = tf.group(train_op, [global_step.assign(new_global_step)])

      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=train_perplexity / num_gpus,
          train_op=train_op,
          scaffold_fn=scaffold_fn)

    elif mode == tf.estimator.ModeKeys.EVAL:
      def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                    masked_lm_weights):
        """Computes the loss and accuracy of the model."""
        masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
                                         [-1, masked_lm_log_probs.shape[-1]])
        masked_lm_predictions = tf.argmax(
            masked_lm_log_probs, axis=-1, output_type=tf.int32)
        masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
        masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
        masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
        masked_lm_accuracy = tf.metrics.accuracy(
            labels=masked_lm_ids,
            predictions=masked_lm_predictions,
            weights=masked_lm_weights)
        masked_lm_mean_loss = tf.metrics.mean(
            values=masked_lm_example_loss, weights=masked_lm_weights)

        return {
            "masked_lm_accuracy": masked_lm_accuracy,
            "masked_lm_loss": masked_lm_mean_loss,
        }

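      # In EVAL mode num_gpus is forced to 1 above, so the single tower's
      # outputs line up with shard 0 of the label ids and weights.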
      eval_metrics = (metric_fn, [
          masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids_list[0],
          masked_lm_weights_list[0]])
      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metrics=eval_metrics,
          scaffold_fn=scaffold_fn)
    else:
      raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

    return output_spec
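

# average_gradients and clip_grads are called above but not defined in this
# file. A minimal sketch of average_gradients for the dense-gradient case,
# assuming the two extra arguments (passed as None above) are unused here;
# the real helper may also handle tf.IndexedSlices and emit summaries:
def average_gradients(tower_grads, batch_size=None, options=None):
  """Averages per-tower lists of (gradient, variable) pairs."""
  average_grads = []
  for grad_and_vars in zip(*tower_grads):
    # grad_and_vars is [(grad_gpu0, v), (grad_gpu1, v), ...] for one variable.
    grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars if g is not None]
    grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
    # All towers share variables, so any tower's variable handle will do.
    average_grads.append((grad, grad_and_vars[0][1]))
  return average_grads
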
def bertmodel(bert_config, bert_init_checkpoint, learning_rate, num_train_steps,
              num_warmup_steps, use_one_hot_embeddings, features, ngpus,
              is_training):

  num_gpu = ngpus if is_training else 1
  optimizer = optimization.create_optimizer_mgpu(learning_rate, num_train_steps, num_warmup_steps)
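  # Keep unsplit copies of the ids so the caller can align the concatenated
  # per-tower predictions with the full input batch.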
  full_query_id = features['query_id']
  full_product_id = features['product_id']

  input_ids = tf.split(features["input_ids"], num_or_size_splits=num_gpu, axis=0)
  segment_ids = tf.split(features["segment_ids"], num_or_size_splits=num_gpu, axis=0)
  boxes = tf.split(features['boxes'], num_or_size_splits=num_gpu, axis=0)
  boxfeat = tf.split(features['features'], num_or_size_splits=num_gpu, axis=0)
  labelfeat = tf.split(features["labelfeat"], num_or_size_splits=num_gpu, axis=0)
  query_id = tf.split(features["query_id"], num_or_size_splits=num_gpu, axis=0)
  product_id = tf.split(features["product_id"], num_or_size_splits=num_gpu, axis=0)
  next_sentence_labels = tf.split(features["next_sentence_labels"], num_or_size_splits=num_gpu, axis=0)

  tower_grads = []
  train_perplexity = 0
  next_sentence_gpu = 0
  next_sentence_acc_gpus = 0
  next_sentence_op_gpus = 0
  empty = True
  # Single-tower inference pinned to /gpu:num_gpu (num_gpu is forced to 1 when
  # not training, so this is /gpu:1); gpuid is then decremented to index the
  # 0-based split lists.
  for gpuid in range(num_gpu, num_gpu + 1):
    with tf.device('/gpu:%d' % gpuid):
      with tf.name_scope('multigpu%d' % gpuid):
        gpuid -= 1
        model = pixelmodel.BertModel(imgfeat=boxfeat[gpuid],
                                     config=bert_config,
                                     is_training=False,
                                     input_ids=input_ids[gpuid],
                                     label_ids=labelfeat[gpuid],
                                     token_type_ids=segment_ids[gpuid],
                                     use_one_hot_embeddings=use_one_hot_embeddings,
                                     random_sample=FLAGS.random_sample)

        (next_sentence_loss, next_sentence_example_loss,
         next_sentence_log_probs, next_sentence_probs) = get_next_sentence_output(
          bert_config, model.get_pooled_output(), next_sentence_labels[gpuid])

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if bert_init_checkpoint and gpuid == 0:
          (assignment_map, initialized_variable_names
           ) = pixelmodel.get_assignment_map_from_checkpoint(tvars, bert_init_checkpoint)
          for var in tvars:
            param_name = var.name[:-2]
            tf.get_variable(
              name=param_name + "/adam_m",
              shape=var.shape.as_list(),
              dtype=tf.float32,
              trainable=False,
              initializer=tf.zeros_initializer())
            tf.get_variable(
              name=param_name + "/adam_v",
              shape=var.shape.as_list(),
              dtype=tf.float32,
              trainable=False,
              initializer=tf.zeros_initializer())

          tf.train.init_from_checkpoint(bert_init_checkpoint, assignment_map)

        tf.get_variable_scope().reuse_variables()

        next_sentence_log_probs = tf.reshape(
          next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
        next_sentence_predictions = tf.argmax(
          next_sentence_log_probs, axis=-1, output_type=tf.int64)

        # The training branch (per-tower gradient computation and
        # loss/accuracy accumulation) is commented out in this version, so
        # tower_grads stays empty and the accumulators initialized above are
        # returned as zeros.
        # Concatenate per-tower probabilities back into a single batch.
        if empty:
          next_sentence_prob_gpus = next_sentence_probs
          empty = False
        else:
          next_sentence_prob_gpus = tf.concat(
              (next_sentence_prob_gpus, next_sentence_probs), axis=0)

  if not is_training:
    return next_sentence_prob_gpus

  global_step = tf.train.get_or_create_global_step()
  new_global_step = global_step + 1

  average_grads = average_gradients(tower_grads, None, None)
  average_grads, norm_summary_ops = clip_grads(average_grads, 1.0, True, global_step)

  train_op = optimizer.apply_gradients(average_grads)
  train_op = tf.group(train_op, [global_step.assign(new_global_step)])

  return (train_op, train_perplexity / num_gpu, next_sentence_gpu / num_gpu,
          next_sentence_op_gpus / num_gpu, next_sentence_acc_gpus / num_gpu,
          next_sentence_prob_gpus, full_query_id, full_product_id)
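

# get_next_sentence_output is not defined in this file either. A sketch
# following BERT's run_pretraining.py head, extended (an assumption, based on
# the 4-tuple unpacked in bertmodel above) to also return the softmax
# probabilities:
def get_next_sentence_output(bert_config, input_tensor, labels):
  """Binary classification head over the pooled [CLS] output."""
  with tf.variable_scope("cls/seq_relationship"):
    output_weights = tf.get_variable(
        "output_weights",
        shape=[2, bert_config.hidden_size],
        initializer=tf.truncated_normal_initializer(
            stddev=bert_config.initializer_range))
    output_bias = tf.get_variable(
        "output_bias", shape=[2], initializer=tf.zeros_initializer())

    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    probs = tf.nn.softmax(logits, axis=-1)

    labels = tf.reshape(labels, [-1])
    one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)
    return (loss, per_example_loss, log_probs, probs)
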
  def model_fn(features, labels, mode, params):
    """The `model_fn` for TPUEstimator."""
    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    num_gpus = n_gpus
    if is_training:
      optimizer = optimization.create_optimizer_mgpu(learning_rate, num_train_steps, num_warmup_steps)
    else:
      num_gpus = 1

    input_ids_list = tf.split(features["input_ids"], num_or_size_splits=num_gpus, axis=0)
    input_mask_list = tf.split(features["input_mask"], num_or_size_splits=num_gpus, axis=0)
    segment_ids_list = tf.split(features["segment_ids"], num_or_size_splits=num_gpus, axis=0)
    label_ids_list = tf.split(features["label_ids"], num_or_size_splits=num_gpus, axis=0)

    tower_grads = []
    train_perplexity = 0
    for index in range(num_gpus):
      with tf.name_scope('replica_%d' % index):
        with tf.device('/gpu:%d' % index):
          (total_loss, per_example_loss, logits) = create_model(
              bert_config, is_training,
              input_ids_list[index], input_mask_list[index], segment_ids_list[index], label_ids_list[index],
              num_labels, use_one_hot_embeddings)

          tvars = tf.trainable_variables()

          scaffold_fn = None
          if init_checkpoint:
            (assignment_map,
             initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
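            # Pre-create the Adam slot variables, as in the pretraining
            # model_fn above.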
            for var in tvars:
              param_name = var.name[:-2]
              tf.get_variable(
                name=param_name + "/adam_m",
                shape=var.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer())
              tf.get_variable(
                name=param_name + "/adam_v",
                shape=var.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer())
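            # On TPU, checkpoint initialization has to happen inside the
            # Scaffold that TPUEstimator invokes on each host; on GPU/CPU it
            # can run directly at graph-construction time.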
            if use_tpu:
              def tpu_scaffold():
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                return tf.train.Scaffold()

              scaffold_fn = tpu_scaffold
            else:
              tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

          tf.logging.info("**** Trainable Variables ****")
          tf.logging.info("device %d initialized", index)
          if index == 0:
            for var in tvars:
              init_string = ""
              if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
              tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                              init_string)
          if is_training:
            # reuse variables
            tf.get_variable_scope().reuse_variables()
            loss = total_loss
            # get gradients
            grads = optimizer.compute_gradients(
              loss,
              aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE,
            )
            tower_grads.append(grads)
            # keep track of loss across all GPUs
            train_perplexity += loss

    if mode == tf.estimator.ModeKeys.TRAIN:
      global_step = tf.train.get_or_create_global_step()
      new_global_step = global_step + 1

      average_grads = average_gradients(tower_grads, None, None)
      #average_grads, norm_summary_ops = clip_grads(average_grads, 1.0, True, global_step)
      train_op = optimizer.apply_gradients(average_grads)
      train_op = tf.group(train_op, [global_step.assign(new_global_step)])
      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=train_perplexity / num_gpus,
          train_op=train_op,
          scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.PREDICT:
      predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          predictions={
              'predictions': predictions,
          })
    elif mode == tf.estimator.ModeKeys.EVAL:
      def metric_fn(per_example_loss, label_ids, logits):
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        accuracy = tf.metrics.accuracy(label_ids, predictions)
        loss = tf.metrics.mean(per_example_loss)
        return {
            "eval_accuracy": accuracy,
            "eval_loss": loss,
        }

      eval_metrics = (metric_fn, [per_example_loss, label_ids_list[0], logits])
      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metrics=eval_metrics,
          scaffold_fn=scaffold_fn)
    else:
      raise ValueError("Only TRAIN, EVAL and PREDICT modes are supported: %s" % (mode))

    return output_spec