Exemplo n.º 1
0
  def _build_model(self, params, num_steps, is_training):
    """Creates the input pipeline and the train or eval graph for NCF.

    Args:
      params: A dict of hyperparameters. "use_xla_for_gpu" decides whether the
        model function is wrapped with `xla.estimator_model_fn`.
      num_steps: Forwarded unchanged to the train/eval graph builder.
      is_training: If True, build the training model. If False, build the
        evaluation model.
    Returns:
      The result of `_build_train_specific_graph` if is_training is True
      (a _TrainModelProperties), or of `_build_eval_specific_graph` otherwise
      (an _EvalModelProperties).
    """
    # Scalar string placeholder through which the record files are supplied.
    files_placeholder = tf.placeholder(tf.string, ())
    input_fn, _, _ = data_preprocessing.make_input_fn(
        ncf_dataset=self._ncf_dataset,
        is_training=is_training,
        record_files=files_placeholder)
    data_iterator = input_fn(params).make_initializable_iterator()

    base_model_fn = neumf_model.neumf_model_fn
    model_fn = (xla.estimator_model_fn(base_model_fn)
                if params["use_xla_for_gpu"] else base_model_fn)

    # Both builders take the same arguments, so select one and call it once.
    if is_training:
      build_graph = self._build_train_specific_graph
    else:
      build_graph = self._build_eval_specific_graph
    return build_graph(
        data_iterator, model_fn, params, files_placeholder, num_steps)
Exemplo n.º 2
0
def construct_estimator(model_dir, params):
    """Build the `tf.estimator.Estimator` used for NCF.

    The distribution strategy returned by
    `ncf_common.get_v1_distribution_strategy` is used for both training and
    evaluation.

    Args:
      model_dir: The model directory for the estimator.
      params: The params dict for the estimator. When "use_xla_for_gpu" is
        truthy the model function is wrapped with `xla.estimator_model_fn`.

    Returns:
      The constructed `tf.estimator.Estimator`.
    """
    strategy = ncf_common.get_v1_distribution_strategy(params)
    config = tf.estimator.RunConfig(
        train_distribute=strategy, eval_distribute=strategy)

    base_model_fn = neumf_model.neumf_model_fn
    if not params["use_xla_for_gpu"]:
        chosen_model_fn = base_model_fn
    else:
        # TODO(seemuch): remove the contrib import
        from tensorflow.contrib.compiler import xla
        logging.info("Using XLA for GPU for training and evaluation.")
        chosen_model_fn = xla.estimator_model_fn(base_model_fn)

    return tf.estimator.Estimator(
        model_fn=chosen_model_fn,
        model_dir=model_dir,
        config=config,
        params=params)
Exemplo n.º 3
0
def construct_estimator(model_dir, params):
  """Build the NCF `tf.estimator.Estimator` with a TPU or GPU strategy.

  When params["use_tpu"] is truthy this resets the TPU session, overwrites the
  `TF_CONFIG` environment variable with the resolved master and coordinator,
  and uses a `TPUStrategy`. Otherwise the strategy comes from
  `distribution_utils.get_distribution_strategy`.

  Args:
    model_dir: The model directory for the estimator
    params: The params dict for the estimator. Reads "use_tpu",
      "use_xla_for_gpu" and, depending on the branch, "tpu", "tpu_zone" and
      "tpu_gcp_project", or "num_gpus".

  Returns:
    A `tf.estimator.Estimator` whose train and eval distribution strategies
    are the same object.
  """
  if not params["use_tpu"]:
    strategy = distribution_utils.get_distribution_strategy(
        num_gpus=params["num_gpus"])
  else:
    # Some of the networking libraries are quite chatty.
    noisy_loggers = ["googleapiclient.discovery",
                     "googleapiclient.discovery_cache",
                     "oauth2client.transport"]
    for logger_name in noisy_loggers:
      logging.getLogger(logger_name).setLevel(logging.ERROR)

    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu=params["tpu"],
        zone=params["tpu_zone"],
        project=params["tpu_gcp_project"],
        coordinator_name="coordinator"
    )

    tf.logging.info("Issuing reset command to TPU to ensure a clean state.")
    tf.Session.reset(resolver.get_master())

    # Estimator looks at the master it connects to for MonitoredTrainingSession
    # by reading the `TF_CONFIG` environment variable, and the coordinator
    # is used by StreamingFilesDataset.
    train_master = resolver.get_master()
    eval_master = resolver.get_master()
    coordinator = resolver.cluster_spec().as_dict()["coordinator"]
    os.environ['TF_CONFIG'] = json.dumps({
        "session_master": train_master,
        "eval_session_master": eval_master,
        "coordinator": coordinator,
    })

    strategy = tf.contrib.distribute.TPUStrategy(resolver, steps_per_run=100)

  config = tf.estimator.RunConfig(
      train_distribute=strategy, eval_distribute=strategy)

  base_model_fn = neumf_model.neumf_model_fn
  if params["use_xla_for_gpu"]:
    tf.logging.info("Using XLA for GPU for training and evaluation.")
    chosen_model_fn = xla.estimator_model_fn(base_model_fn)
  else:
    chosen_model_fn = base_model_fn

  return tf.estimator.Estimator(
      model_fn=chosen_model_fn,
      model_dir=model_dir,
      config=config,
      params=params)
Exemplo n.º 4
0
def construct_estimator(model_dir, params):
  """Build the NCF `tf.estimator.Estimator` with a TPU or GPU strategy.

  When params["use_tpu"] is truthy this resets the TPU session, overwrites the
  `TF_CONFIG` environment variable with the resolved master and coordinator,
  and uses a `TPUStrategy`. Otherwise the strategy comes from
  `distribution_utils.get_distribution_strategy`.

  Args:
    model_dir: The model directory for the estimator
    params: The params dict for the estimator. Reads "use_tpu",
      "use_xla_for_gpu" and, depending on the branch, "tpu", "tpu_zone" and
      "tpu_gcp_project", or "num_gpus".

  Returns:
    A `tf.estimator.Estimator` whose train and eval distribution strategies
    are the same object.
  """
  if not params["use_tpu"]:
    strategy = distribution_utils.get_distribution_strategy(
        num_gpus=params["num_gpus"])
  else:
    # Some of the networking libraries are quite chatty.
    noisy_loggers = ["googleapiclient.discovery",
                     "googleapiclient.discovery_cache",
                     "oauth2client.transport"]
    for logger_name in noisy_loggers:
      logging.getLogger(logger_name).setLevel(logging.ERROR)

    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu=params["tpu"],
        zone=params["tpu_zone"],
        project=params["tpu_gcp_project"],
        coordinator_name="coordinator"
    )

    tf.logging.info("Issuing reset command to TPU to ensure a clean state.")
    tf.Session.reset(resolver.get_master())

    # Estimator looks at the master it connects to for MonitoredTrainingSession
    # by reading the `TF_CONFIG` environment variable, and the coordinator
    # is used by StreamingFilesDataset.
    train_master = resolver.get_master()
    eval_master = resolver.get_master()
    coordinator = resolver.cluster_spec().as_dict()["coordinator"]
    os.environ['TF_CONFIG'] = json.dumps({
        "session_master": train_master,
        "eval_session_master": eval_master,
        "coordinator": coordinator,
    })

    strategy = tf.contrib.distribute.TPUStrategy(resolver, steps_per_run=100)

  config = tf.estimator.RunConfig(
      train_distribute=strategy, eval_distribute=strategy)

  base_model_fn = neumf_model.neumf_model_fn
  if params["use_xla_for_gpu"]:
    tf.logging.info("Using XLA for GPU for training and evaluation.")
    chosen_model_fn = xla.estimator_model_fn(base_model_fn)
  else:
    chosen_model_fn = base_model_fn

  return tf.estimator.Estimator(
      model_fn=chosen_model_fn,
      model_dir=model_dir,
      config=config,
      params=params)
Exemplo n.º 5
0
    def _build_model(self, params, is_training):
        """Creates the input pipeline and the train or eval ops for NCF.

        Args:
          params: A dict of hyperparameters. Reads "use_xla_for_gpu" and
            either "batch_size" (training) or "eval_batch_size" (evaluation).
          is_training: If True, build the training model. If False, build the
            evaluation model.
        Returns:
          A _TrainModelProperties if is_training is True, or an
          _EvalModelProperties otherwise.
        """
        # Scalar string placeholder through which record files are supplied.
        files_placeholder = tf.placeholder(tf.string, ())
        input_fn, _, _ = data_preprocessing.make_input_fn(
            ncf_dataset=self._ncf_dataset,
            is_training=is_training,
            record_files=files_placeholder)
        data_iterator = input_fn(params).make_initializable_iterator()

        base_model_fn = neumf_model.neumf_model_fn
        model_fn = (xla.estimator_model_fn(base_model_fn)
                    if params["use_xla_for_gpu"] else base_model_fn)

        if not is_training:
            # The eval iterator yields features only; no labels are passed.
            eval_features = data_iterator.get_next()
            eval_spec = model_fn(eval_features, None,
                                 tf.estimator.ModeKeys.EVAL, params)
            # Each metric is a (value, update_op) pair; one step runs every
            # update op together.
            metric_updates = [
                update for _, update in eval_spec.eval_metric_ops.values()
            ]
            run_eval_op = tf.group(*metric_updates)
            reset_metrics_op = tf.variables_initializer(
                tf.get_collection(tf.GraphKeys.METRIC_VARIABLES))
            return self._EvalModelProperties(
                files_placeholder, data_iterator, eval_spec.loss,
                params["eval_batch_size"], run_eval_op,
                eval_spec.eval_metric_ops, reset_metrics_op)

        train_features, train_labels = data_iterator.get_next()
        train_spec = model_fn(train_features, train_labels,
                              tf.estimator.ModeKeys.TRAIN, params)
        # Increment the global step only after the train op has run.
        with tf.control_dependencies([train_spec.train_op]):
            run_train_op = self._global_step.assign_add(1)
        return self._TrainModelProperties(
            files_placeholder, data_iterator, train_spec.loss,
            params["batch_size"], run_train_op)
Exemplo n.º 6
0
def construct_estimator(model_dir, params):
  """Build the `tf.estimator.Estimator` used for NCF.

  The distribution strategy returned by `ncf_common.get_distribution_strategy`
  is used for both training and evaluation.

  Args:
    model_dir: The model directory for the estimator
    params: The params dict for the estimator. When "use_xla_for_gpu" is
      truthy the model function is wrapped with `xla.estimator_model_fn`.

  Returns:
    The constructed `tf.estimator.Estimator`.
  """
  strategy = ncf_common.get_distribution_strategy(params)
  config = tf.estimator.RunConfig(
      train_distribute=strategy, eval_distribute=strategy)

  base_model_fn = neumf_model.neumf_model_fn
  if params["use_xla_for_gpu"]:
    tf.logging.info("Using XLA for GPU for training and evaluation.")
    chosen_model_fn = xla.estimator_model_fn(base_model_fn)
  else:
    chosen_model_fn = base_model_fn

  return tf.estimator.Estimator(
      model_fn=chosen_model_fn,
      model_dir=model_dir,
      config=config,
      params=params)
Exemplo n.º 7
0
def construct_estimator(num_gpus, model_dir, iterations, params, batch_size,
                        eval_batch_size):
  """Build the train and eval estimators for NCF.

  Without TPU, a single `tf.estimator.Estimator` is created and returned
  twice, and `params["eval_batch_size"]` is set on the caller's dict. With
  TPU, the TPU session is reset and two separate `TPUEstimator`s are built.

  Args:
    num_gpus: The number of gpus (Used to select distribution strategy)
    model_dir: The model directory for the estimator
    iterations:  Estimator iterations; used as `iterations_per_loop` in the
      TPU configuration.
    params: The params dict for the estimator
    batch_size: The mini-batch size for training.
    eval_batch_size: The batch size used during evaluation.

  Returns:
    A `(train_estimator, eval_estimator)` tuple.
  """
  if not params["use_tpu"]:
    strategy = distribution_utils.get_distribution_strategy(num_gpus=num_gpus)
    gpu_config = tf.estimator.RunConfig(
        train_distribute=strategy, eval_distribute=strategy)
    params["eval_batch_size"] = eval_batch_size

    base_model_fn = neumf_model.neumf_model_fn
    if params["use_xla_for_gpu"]:
      tf.logging.info("Using XLA for GPU for training and evaluation.")
      chosen_model_fn = xla.estimator_model_fn(base_model_fn)
    else:
      chosen_model_fn = base_model_fn

    shared_estimator = tf.estimator.Estimator(
        model_fn=chosen_model_fn,
        model_dir=model_dir,
        config=gpu_config,
        params=params)
    return shared_estimator, shared_estimator

  resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      tpu=params["tpu"],
      zone=params["tpu_zone"],
      project=params["tpu_gcp_project"],
  )
  tf.logging.info("Issuing reset command to TPU to ensure a clean state.")
  tf.Session.reset(resolver.get_master())

  loop_config = tf.contrib.tpu.TPUConfig(
      iterations_per_loop=iterations, num_shards=8)
  session_config = tf.ConfigProto(
      allow_soft_placement=True, log_device_placement=False)
  tpu_run_config = tf.contrib.tpu.RunConfig(
      cluster=resolver,
      model_dir=model_dir,
      save_checkpoints_secs=600,
      session_config=session_config,
      tpu_config=loop_config)

  # "batch_size" is removed from the params handed to TPUEstimator; the batch
  # sizes are passed as explicit constructor arguments instead.
  tpu_params = {
      key: value for key, value in params.items() if key != "batch_size"
  }

  def make_tpu_estimator(train_batch_size):
    """Returns a TPUEstimator that differs only in its train batch size."""
    return tf.contrib.tpu.TPUEstimator(
        model_fn=neumf_model.neumf_model_fn,
        use_tpu=True,
        train_batch_size=train_batch_size,
        eval_batch_size=eval_batch_size,
        params=tpu_params,
        config=tpu_run_config)

  train_estimator = make_tpu_estimator(batch_size)
  eval_estimator = make_tpu_estimator(1)
  return train_estimator, eval_estimator
Exemplo n.º 8
0
def main(unused_argv):
    """Runs the ResNet TPUEstimator in the mode selected by `FLAGS.mode`.

    Builds the hyperparameters, the TPU run configuration, the estimator and
    the train/eval input pipelines, then dispatches on `FLAGS.mode`:

      * 'eval': evaluates each checkpoint yielded for the model directory and
        stops once a checkpoint at or beyond params['train_steps'] has been
        evaluated.
      * 'eval_igt': evaluates the checkpoints currently in the model
        directory in increasing step order, persisting the index of the next
        checkpoint to a state file so an interrupted run can resume.
      * anything else (expected: 'train' or 'train_and_eval'): trains, in
        'train_and_eval' alternating with evaluation every
        `FLAGS.steps_per_eval` steps, then exports a SavedModel when
        `FLAGS.export_dir` is set.

    Args:
      unused_argv: Ignored; never read in the body.

    Raises:
      AssertionError: If params['precision'] is neither 'bfloat16' nor
        'float32', or if the training branch is reached with a mode other
        than 'train' / 'train_and_eval'.
      ValueError: In 'eval_igt' mode, if `FLAGS.igt_eval_set` is neither
        'train' nor 'eval'.
    """
    params = hyperparameters.get_hyperparameters(FLAGS.default_hparams_file,
                                                 FLAGS.hparams_file, FLAGS,
                                                 FLAGS.hparams)
    # An empty TPU name is passed when neither --tpu nor params['use_tpu'] is
    # set.
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu if (FLAGS.tpu or params['use_tpu']) else '',
        zone=FLAGS.tpu_zone,
        project=FLAGS.gcp_project)

    # With async checkpointing, step-based saving through RunConfig is
    # disabled here; the async saver hook is added in 'train' mode below.
    # NOTE(review): no such hook is added in 'train_and_eval' mode - confirm
    # that this is intended.
    if params['use_async_checkpointing']:
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = max(2500, params['iterations_per_loop'])
    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=get_model_dir(params),
        save_checkpoints_steps=save_checkpoints_steps,
        keep_checkpoint_max=None,  # Keep all checkpoints.
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=tf.ConfigProto(
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True))),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=params['iterations_per_loop'],
            num_shards=params['num_cores'],
            # copybara:strip_begin
            tpu_job_name=FLAGS.tpu_job_name,
            # copybara:strip_end
            per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig
            .PER_HOST_V2))  # pylint: disable=line-too-long

    resnet_classifier = tf.contrib.tpu.TPUEstimator(
        use_tpu=params['use_tpu'],
        model_fn=resnet_model_fn,
        config=config,
        params=params,
        train_batch_size=params['train_batch_size'],
        eval_batch_size=params['eval_batch_size'],
        export_to_tpu=FLAGS.export_to_tpu)

    # copybara:strip_begin
    # When --xla_compile is set, the estimator built above is replaced by one
    # whose model_fn is wrapped with xla.estimator_model_fn; all other
    # arguments are the same.
    if FLAGS.xla_compile:
        resnet_classifier = tf.contrib.tpu.TPUEstimator(
            use_tpu=params['use_tpu'],
            model_fn=xla.estimator_model_fn(resnet_model_fn),
            config=config,
            params=params,
            train_batch_size=params['train_batch_size'],
            eval_batch_size=params['eval_batch_size'],
            export_to_tpu=FLAGS.export_to_tpu)
    # copybara:strip_end
    # NOTE(review): this validation uses `assert`, which is skipped when
    # Python runs with -O.
    assert (params['precision'] == 'bfloat16' or params['precision']
            == 'float32'), ('Invalid value for precision parameter; '
                            'must be bfloat16 or float32.')
    tf.logging.info('Precision: %s', params['precision'])
    use_bfloat16 = params['precision'] == 'bfloat16'

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    if FLAGS.bigtable_instance:
        tf.logging.info('Using Bigtable dataset, table %s',
                        FLAGS.bigtable_table)
        select_train, select_eval = _select_tables_from_flags()
        imagenet_train = imagenet_input.ImageNetBigtableInput(
            is_training=True,
            use_bfloat16=use_bfloat16,
            transpose_input=params['transpose_input'],
            selection=select_train)
        imagenet_eval = imagenet_input.ImageNetBigtableInput(
            is_training=False,
            use_bfloat16=use_bfloat16,
            transpose_input=params['transpose_input'],
            selection=select_eval)
    else:
        if FLAGS.data_dir == FAKE_DATA_DIR:
            tf.logging.info('Using fake dataset.')
        else:
            tf.logging.info('Using dataset: %s', FLAGS.data_dir)
        # Builds the train and eval inputs with identical settings, except
        # that caching is only ever enabled for the training input.
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=params['transpose_input'],
                cache=params['use_cache'] and is_training,
                image_size=params['image_size'],
                num_parallel_calls=params['num_parallel_calls'],
                use_bfloat16=use_bfloat16) for is_training in [True, False]
        ]

    # Floor division: a trailing partial batch is not counted as a step.
    steps_per_epoch = params['num_train_images'] // params['train_batch_size']
    eval_steps = params['num_eval_images'] // params['eval_batch_size']

    if FLAGS.mode == 'eval':

        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(
                get_model_dir(params), timeout=FLAGS.eval_timeout):
            tf.logging.info('Starting to evaluate.')
            try:
                start_timestamp = time.time(
                )  # This time will include compilation time
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                                eval_results, elapsed_time)

                # Terminate eval job when final checkpoint is reached
                # The step is parsed from the checkpoint file name, taking
                # the text between the first and second '-' of the basename.
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= params['train_steps']:
                    tf.logging.info(
                        'Evaluation finished after training step %d',
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)

    elif FLAGS.mode == 'eval_igt':
        # IGT evaluation mode. Evaluate metrics for the desired parameters
        # (true or shifted) on the desired dataset (train or eval). Note that
        # train is still with data augmentation.

        # Get checkpoint file names.
        index_files = tf.gfile.Glob(
            os.path.join(get_model_dir(params), 'model.ckpt-*.index'))
        checkpoints = [fn[:-len('.index')] for fn in index_files]
        # Need to sort them to get proper tensorboard plotting (increasing event
        # timestamps correspond to increasing steps).
        checkpoint_steps = []
        for ckpt in checkpoints:
            tf.logging.info(ckpt)
            # NOTE(review): `[0-9]*` can match an empty string, in which case
            # int('') raises ValueError - presumably every matching file has
            # a numeric suffix; confirm.
            step_match = re.match(r'.*model.ckpt-([0-9]*)', ckpt)
            checkpoint_steps.append(int(step_match.group(1)))
        checkpoints = [
            ckpt for _, ckpt in sorted(zip(checkpoint_steps, checkpoints))
        ]
        tf.logging.info('There are {} checkpoints'.format(len(checkpoints)))
        tf.logging.info(', '.join(checkpoints))

        # Keep track of the last processed checkpoint (fault tolerance).
        # The state file name includes the eval set and eval mode, so each
        # combination keeps its own progress.
        analysis_state_path = os.path.join(
            get_model_dir(params),
            'analysis_state_' + FLAGS.igt_eval_set + '_' + FLAGS.igt_eval_mode)
        next_analysis_index = 0
        if tf.gfile.Exists(analysis_state_path):
            with tf.gfile.Open(analysis_state_path) as fd:
                next_analysis_index = int(fd.read())

        # Process each checkpoint.
        while next_analysis_index < len(checkpoints):
            tf.logging.info(
                'Next analysis index: {}'.format(next_analysis_index))
            ckpt_path = checkpoints[next_analysis_index]
            tf.logging.info('Starting to evaluate: {}.'.format(ckpt_path))
            start_timestamp = time.time(
            )  # This time will include compilation time

            if FLAGS.igt_eval_set == 'train':
                the_input_fn = imagenet_train.input_fn
                the_steps = steps_per_epoch
            elif FLAGS.igt_eval_set == 'eval':
                the_input_fn = imagenet_eval.input_fn
                the_steps = eval_steps
            else:
                raise ValueError('Unsupported igt_eval_set')

            eval_results = resnet_classifier.evaluate(
                input_fn=the_input_fn,
                steps=the_steps,
                checkpoint_path=ckpt_path,
                name=FLAGS.igt_eval_set + '_' + FLAGS.igt_eval_mode)
            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                            eval_results, elapsed_time)

            # Persist progress only after the checkpoint was evaluated.
            next_analysis_index += 1
            file_io.atomic_write_string_to_file(analysis_state_path,
                                                str(next_analysis_index))

    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(
            get_model_dir(params))  # pylint:disable=protected-access,g-line-too-long
        # Recomputes the same value as the `steps_per_epoch` assigned before
        # the mode dispatch.
        steps_per_epoch = params['num_train_images'] // params[
            'train_batch_size']
        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', params['train_steps'],
            params['train_steps'] / steps_per_epoch, current_step)

        start_timestamp = time.time(
        )  # This time will include compilation time

        if FLAGS.mode == 'train':
            hooks = []
            if params['use_async_checkpointing']:
                hooks.append(
                    async_checkpoint.AsyncCheckpointSaverHook(
                        checkpoint_dir=get_model_dir(params),
                        save_steps=max(2500, params['iterations_per_loop'])))
            resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                    max_steps=params['train_steps'],
                                    hooks=hooks)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < params['train_steps']:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      params['train_steps'])
                resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                        max_steps=next_checkpoint)
                current_step = next_checkpoint

                tf.logging.info(
                    'Finished training up to step %d. Elapsed seconds %d.',
                    next_checkpoint, int(time.time() - start_timestamp))

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be excluded modulo the batch size. As long as the batch size is
                # consistent, the evaluated images are also consistent.
                tf.logging.info('Starting to evaluate.')
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=params['num_eval_images'] //
                    params['eval_batch_size'])
                tf.logging.info('Eval results at step %d: %s', next_checkpoint,
                                eval_results)

            # This summary sits inside the 'train_and_eval' branch, so plain
            # 'train' mode does not log it.
            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                params['train_steps'], elapsed_time)

        # Export runs after either training mode, never after the eval modes.
        if FLAGS.export_dir is not None:
            # The guide to serve a exported TensorFlow model is at:
            #    https://www.tensorflow.org/serving/serving_basic
            tf.logging.info('Starting to export model.')
            unused_export_path = resnet_classifier.export_saved_model(
                export_dir_base=FLAGS.export_dir,
                serving_input_receiver_fn=imagenet_input.image_serving_input_fn
            )
Exemplo n.º 9
0
def construct_estimator(num_gpus, model_dir, params, batch_size,
                        eval_batch_size):
    """Build the train and eval estimators for NCF.

    Without TPU, a single `tf.estimator.Estimator` is created and returned
    twice, and `params["eval_batch_size"]` is set on the caller's dict. With
    TPU, the TPU session is reset and two `TPUEstimator`s are built: the
    training one with `use_tpu=True` and the evaluation one with
    `use_tpu=False`.

    Args:
      num_gpus: The number of gpus (Used to select distribution strategy)
      model_dir: The model directory for the estimator
      params: The params dict for the estimator
      batch_size: The mini-batch size for training.
      eval_batch_size: The batch size used during evaluation.

    Returns:
      A `(train_estimator, eval_estimator)` tuple.
    """
    if not params["use_tpu"]:
        strategy = distribution_utils.get_distribution_strategy(
            num_gpus=num_gpus)
        gpu_config = tf.estimator.RunConfig(
            train_distribute=strategy, eval_distribute=strategy)
        params["eval_batch_size"] = eval_batch_size

        base_model_fn = neumf_model.neumf_model_fn
        if params["use_xla_for_gpu"]:
            tf.logging.info("Using XLA for GPU for training and evaluation.")
            chosen_model_fn = xla.estimator_model_fn(base_model_fn)
        else:
            chosen_model_fn = base_model_fn

        shared_estimator = tf.estimator.Estimator(
            model_fn=chosen_model_fn,
            model_dir=model_dir,
            config=gpu_config,
            params=params)
        return shared_estimator, shared_estimator

    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu=params["tpu"],
        zone=params["tpu_zone"],
        project=params["tpu_gcp_project"],
    )
    tf.logging.info("Issuing reset command to TPU to ensure a clean state.")
    tf.Session.reset(resolver.get_master())

    loop_config = tf.contrib.tpu.TPUConfig(
        iterations_per_loop=100, num_shards=8)
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)
    tpu_run_config = tf.contrib.tpu.RunConfig(
        cluster=resolver,
        model_dir=model_dir,
        session_config=session_config,
        tpu_config=loop_config)

    # "batch_size" is removed from the params handed to TPUEstimator; the
    # batch sizes are passed as explicit constructor arguments instead.
    tpu_params = {
        key: value for key, value in params.items() if key != "batch_size"
    }

    train_estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=neumf_model.neumf_model_fn,
        use_tpu=True,
        train_batch_size=batch_size,
        params=tpu_params,
        config=tpu_run_config)

    # The evaluation estimator is deliberately built with use_tpu=False.
    eval_estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=neumf_model.neumf_model_fn,
        use_tpu=False,
        train_batch_size=1,
        eval_batch_size=eval_batch_size,
        params=tpu_params,
        config=tpu_run_config)

    return train_estimator, eval_estimator
Exemplo n.º 10
0
class XlaDecoratorTest(test.TestCase, parameterized.TestCase):
    """Tests for model functions wrapped with `xla.estimator_model_fn`."""

    @parameterized.named_parameters(
        ('test_use_as_decorator', decorated_model_fn, None),
        ('test_use_as_function',
         xla.estimator_model_fn(_test_train_model_fn), None),
        ('test_use_tpu_false_hparams', decorated_model_fn,
         hparam.HParams(use_tpu=False)),
        ('test_use_tpu_false_dict_params', decorated_model_fn,
         {'use_tpu': False}),
    )
    def test_compile(self, model_fn, params):
        """Expects exactly one `xla.compile` call for these parameters."""
        with test.mock.patch.object(xla, 'compile') as compile_mock:
            # The patched compile returns the loss tensor as its only output.
            expected_loss = constant_op.constant(_EXPECTED_LOSS)
            compile_mock.return_value = [expected_loss]

            features, labels = make_dummy_features_labels()
            spec = model_fn(
                features=features,
                labels=labels,
                mode=_TRAIN,
                params=params if params else {})

            self.assertEqual(compile_mock.call_count, 1)
            self.assertEqual(spec.mode, _TRAIN)

            with self.test_session() as session:
                actual_loss = session.run(spec.loss)
                self.assertEqual(actual_loss, session.run(expected_loss))
                actual_train = session.run(spec.train_op)
                self.assertEqual(actual_train, session.run(expected_loss))

    @parameterized.named_parameters(
        ('test_use_tpu_true_hparams', decorated_model_fn,
         hparam.HParams(use_tpu=True)),
        ('test_use_tpu_true_dict_params', decorated_model_fn,
         {'use_tpu': True}),
    )
    def test_not_compile(self, model_fn, params):
        """Expects `xla.compile` not to be called when use_tpu is True."""
        with test.mock.patch.object(xla, 'compile') as compile_mock:
            expected_loss = constant_op.constant(_EXPECTED_LOSS)
            compile_mock.return_value = [expected_loss]

            features, labels = make_dummy_features_labels()
            spec = model_fn(
                features=features,
                labels=labels,
                mode=_TRAIN,
                params=params if params else {})

            compile_mock.assert_not_called()
            self.assertEqual(spec.mode, _TRAIN)

            with self.test_session() as session:
                actual_loss = session.run(spec.loss)
                self.assertEqual(actual_loss, session.run(expected_loss))
                actual_train = session.run(spec.train_op)
                self.assertEqual(actual_train, session.run(expected_loss))

    def test_model_with_summary(self):
        """Tests that summary ops are disabled."""

        @xla.estimator_model_fn
        def model_fn_with_summary(features, labels, mode, params):
            # Inputs are unused; the model only emits a constant loss and
            # some summaries.
            del features, labels, params
            constant_loss = constant_op.constant(_EXPECTED_LOSS)
            summary.scalar('loss_scalar_summary', constant_loss)
            summary.histogram('loss_histogram_summary', constant_loss)
            summary.image('loss_image_summary', constant_loss)
            return model_fn_lib.EstimatorSpec(
                mode=mode,
                loss=constant_loss,
                train_op=array_ops.identity(constant_loss))

        features, labels = make_dummy_features_labels()
        spec = model_fn_with_summary(
            features=features, labels=labels, mode=_TRAIN, params={})

        with self.test_session() as session:
            self.assertEqual(session.run(spec.loss), _EXPECTED_LOSS)