Code Example #1
def main(_):
    os.environ["TFHUB_CACHE_DIR"] = os.path.join(FLAGS.model_dir,
                                                 "tfhub_modules")
    os.environ["TFHUB_DOWNLOAD_PROGRESS"] = "True"
    input_fn = partial(input_, iterations=FLAGS.iterations)
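    # Resolve the TPU cluster from FLAGS.tpu and build the TPU run configuration.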
    cluster = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    run_config = tpu_config.RunConfig(model_dir=FLAGS.model_dir,
                                      cluster=cluster,
                                      tpu_config=tpu_config.TPUConfig(
                                          FLAGS.iterations))

    classifier = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                            use_tpu=FLAGS.use_tpu,
                                            train_batch_size=FLAGS.batch_size,
                                            eval_batch_size=FLAGS.batch_size,
                                            config=run_config,
                                            params={
                                                "use_tpu": FLAGS.use_tpu,
                                                "data_dir": FLAGS.data_dir,
                                                "dataset": FLAGS.dataset
                                            })

    classifier.train(input_fn=lambda params: input_fn(
        mode=tf.estimator.ModeKeys.TRAIN, **params),
                     max_steps=2000)
Code Example #2
 def _make_task_estimator(task):
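     """Return a TPUEstimator whose params carry the given eval task."""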
     task_params = params.copy()
     task_params["eval_task"] = task
     return tpu_estimator.TPUEstimator(
         use_tpu=params["use_tpu"],
         model_fn=model_fn,
         config=config,
         train_batch_size=params["train_batch_size"],
         eval_batch_size=params["train_batch_size"],
         predict_batch_size=params["predict_batch_size"],
         params=task_params)
Code Example #3
  def testAsyncCheckpointHookEnabled(self):
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)

    checkpoint_interval = 5
    config = tpu_config.RunConfig(
        master=resolver.master(),
        model_dir=os.path.join(FLAGS.model_dir, 'runconfig'),
        save_checkpoints_steps=1000,
        keep_checkpoint_max=11,  # off by one
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=checkpoint_interval,))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        config=config,
        train_batch_size=32,
        eval_batch_size=32,
        predict_batch_size=1,
        params={},
    )

    i = 10
    mock_listener = test.mock.create_autospec(
        basic_session_run_hooks.CheckpointSaverListener)
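    # Train with the async checkpoint hook; the mock listener records each save callback.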
    estimator.train(
        input_fn=input_fn,
        max_steps=i * 10,
        hooks=[
            async_checkpoint.AsyncCheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=checkpoint_interval,
                listeners=[mock_listener])
        ])

    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access

    # TODO(power) -- identify a better way to count the number of checkpoints.
    checkpoints = file_io.get_matching_files(
        FLAGS.model_dir + '/model.ckpt*.meta')
    checkpoint_count = len(checkpoints)
    logging.info('Found %d checkpoints: %s', checkpoint_count, checkpoints)
    self.assertLessEqual(checkpoint_count, 10)
    self.assertEqual(current_step, i * 10)
    mock_listener.before_save.assert_called()
    mock_listener.after_save.assert_called()
Code Example #4
def main(_):
    os.environ["TFHUB_CACHE_DIR"] = os.path.join(FLAGS.model_dir,
                                                 "tfhub_modules")
    os.environ["TFHUB_DOWNLOAD_PROGRESS"] = "True"
    input_fn = partial(input_, iterations=FLAGS.iterations)
    cluster = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    run_config = tpu_config.RunConfig(model_dir=FLAGS.model_dir,
                                      cluster=cluster,
                                      tpu_config=tpu_config.TPUConfig(
                                          FLAGS.iterations))

    classifier = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                            use_tpu=FLAGS.use_tpu,
                                            train_batch_size=FLAGS.batch_size,
                                            eval_batch_size=FLAGS.batch_size,
                                            config=run_config,
                                            params={
                                                "use_tpu":
                                                FLAGS.use_tpu,
                                                "data_dir":
                                                FLAGS.data_dir,
                                                "dataset":
                                                FLAGS.dataset,
                                                "use_compat":
                                                FLAGS.use_compat,
                                                "learning_rate":
                                                FLAGS.learning_rate
                                            })
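    # Train, but swallow any exception so the optional inference step below still runs.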
    try:
        classifier.train(input_fn=lambda params: input_fn(
            mode=tf.estimator.ModeKeys.TRAIN, **params),
                         max_steps=FLAGS.max_steps)
    except Exception:
        pass
    if FLAGS.infer:

        def prepare_input_fn(path):
            img = tf.image.decode_image(tf.io.read_file(path))
            return resize_and_scale(img, None)

        predictions = classifier.predict(
            input_fn=lambda params: prepare_input_fn(FLAGS.infer))
        print(predictions)
Code Example #5
File: mesh_lpt_TPU.py Project: ml-lab/flowpm
def main(_):

  tf.logging.set_verbosity(tf.logging.INFO)
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)

  # Resolve the TPU environment
  tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu,
      zone=FLAGS.tpu_zone,
      project=FLAGS.gcp_project
  )

  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=None,  # Disable the default saver
      save_checkpoints_secs=None,  # Disable the default saver
      log_step_count_steps=100,
      save_summary_steps=100,
      tpu_config=tpu_config.TPUConfig(
          num_shards=mesh_shape.size,
          iterations_per_loop=100,
          num_cores_per_replica=1,
          per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))

  model = tpu_estimator.TPUEstimator(
      use_tpu=True,
      model_fn=model_fn,
      config=run_config,
      predict_batch_size=1,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size)

  def dummy_input_fn(params):
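    """Dummy input_fn returning a dataset of zeros, just enough to drive predict for profiling."""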
    dset = tf.data.Dataset.from_tensor_slices(tf.zeros(shape=[params['batch_size'],1],
                                                       dtype=tf.float32))
    return dset

  # Run the predict loop forever; we will connect to this process with a profiler.
  for i, f in enumerate(model.predict(input_fn=dummy_input_fn)):
    print(i)
    np.save(file_io.FileIO(FLAGS.output_dir+'/field_%d.npy'%i, 'w'), f['field'])
Code Example #6
def run_toy_model_tpu():
    """Run a toy model on TPU."""
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    iterations_per_loop = FLAGS.iterations
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=None,  # Disable the default saver
        save_checkpoints_secs=None,  # Disable the default saver
        log_step_count_steps=iterations_per_loop,
        save_summary_steps=iterations_per_loop,
        tpu_config=tpu_config.TPUConfig(
            num_shards=mesh_shape.size,
            iterations_per_loop=iterations_per_loop,
            num_cores_per_replica=1,
            per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))
    classifier = tpu_estimator.TPUEstimator(use_tpu=True,
                                            model_fn=model_fn,
                                            config=config,
                                            train_batch_size=FLAGS.batch_size,
                                            eval_batch_size=FLAGS.batch_size)
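    # Read the global step from the latest checkpoint so training can resume.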
    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    logging.info('Current step %d', current_step)
    if FLAGS.steps_per_checkpoint == 0:
        classifier.train(input_fn=ToyModelInput(), max_steps=FLAGS.train_steps)
        return
    while current_step < FLAGS.train_steps:
        next_checkpoint = min(current_step + FLAGS.steps_per_checkpoint,
                              FLAGS.train_steps)
        classifier.train(input_fn=ToyModelInput(), max_steps=next_checkpoint)
        current_step = next_checkpoint
        logging.info('Starting to evaluate.')
        eval_results = classifier.evaluate(
            input_fn=ToyModelInput(), steps=156
        )  # since we have 10000 examples and batch_size = 64 per host
        logging.info('Eval results: %s', eval_results)
Code Example #7
def main(_):

  tf.logging.set_verbosity(tf.logging.INFO)
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)

  # Resolve the TPU environment
  tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=None,  # Disable the default saver
      save_checkpoints_secs=None,  # Disable the default saver
      log_step_count_steps=100,
      save_summary_steps=100,
      tpu_config=tpu_config.TPUConfig(
          num_shards=mesh_shape.size,
          iterations_per_loop=100,
          num_cores_per_replica=1,
          per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))

  model = tpu_estimator.TPUEstimator(
      use_tpu=True,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size)

  def dummy_input_fn(params):
    """Dummy input function """
    return tf.zeros(
        shape=[params['batch_size']], dtype=tf.float32), tf.zeros(
            shape=[params['batch_size']], dtype=tf.float32)

  # Run the evaluate loop forever; we will connect to this process with a profiler.
  model.evaluate(input_fn=dummy_input_fn, steps=10000)
Code Example #8
def main():
    # parse args and params
    args = parse_args()
    logging = setup_logging(args)
    params = fetch_model_params(args.model)
    assert params["model_type"].lower(
    ) == "vae", f'model_type {params["model_type"]} not recognized'

    # Confirm deletion of checkpoint files if --new flag is set
    if args.new:
        maybe_remove_gs_or_filepath(params["model_path"])

    # get current step
    current_step = int(
        estimator_lib._load_global_step_from_checkpoint_dir(
            params["model_path"]))
    logging.info(f"Current step: {current_step}")

    # Add to params:
    params["use_tpu"] = True if not args.tpu is None else False
    params["gpu_ids"] = args.gpu_ids

    # Set up TPUs and Estimator
    if args.tpu == "colab":
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        ) if params["use_tpu"] else None
    else:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            args.tpu) if params["use_tpu"] else None

    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=params["model_path"],
        save_checkpoints_steps=params["steps_per_checkpoint"],
        log_step_count_steps=params["iterations"],
        save_summary_steps=params["iterations"],
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=params["iterations"],
            num_cores_per_replica=1,
            experimental_host_call_every_n_steps=100,
            per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=params["use_tpu"],
        model_fn=vae_model_fn,
        config=config,
        train_batch_size=params["train_batch_size"],
        eval_batch_size=params["eval_batch_size"],
        predict_batch_size=params["predict_batch_size"],
        params=params)

    has_predict_or_eval_steps = params["predict_steps"] > 0 or params[
        "eval_steps"] > 0
    if has_predict_or_eval_steps:
        # Eval and train - stop and predict and/or eval every checkpoint
        while current_step < params["train_steps"]:
            next_checkpoint = min(
                current_step + params["steps_per_checkpoint"],
                params["train_steps"])
            estimator.train(input_fn=partial(vae_input_fn, eval=False),
                            max_steps=next_checkpoint)
            current_step = next_checkpoint
            logging.info(f"Current step: {current_step}")
            if params["predict_steps"] > 0:
                raise NotImplementedError
            if params["eval_steps"] > 0:
                logging.info(f"Starting eval")
                estimator.evaluate(input_fn=partial(vae_input_fn, eval=True),
                                   steps=params["eval_steps"])

        return
    else:
        # Else, just train
        while current_step < params["train_steps"]:
            # Else, don't stop and restart
            estimator.train(input_fn=partial(vae_input_fn, eval=False),
                            max_steps=params["train_steps"])
Code Example #9
File: train_dalle.py Project: lucidrains/DALLE-mtf
def main():
    # parse args and params
    args = parse_args()
    logging = setup_logging(args)
    params = fetch_model_params(args.model)
    params["vae_params"] = fetch_model_params(params["vae_model"])
    assert params["model_type"].lower(
    ) == "dalle", f'model_type {params["model_type"]} not recognized'

    # Confirm deletion of checkpoint files if --new flag is set
    if args.new:
        maybe_remove_gs_or_filepath(params["model_path"])

    # get current step
    current_step = int(
        estimator_lib._load_global_step_from_checkpoint_dir(
            params["model_path"]))
    logging.info(f"Current step: {current_step}")

    # Add to params:
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    params["num_cores"] = mesh_shape.size
    params["use_tpu"] = True if not args.tpu is None else False
    params["gpu_ids"] = args.gpu_ids
    tokenizer = get_tokenizer(params["tokenizer"])
    assert len(tokenizer) == params["text_vocab_size"], \
        f"tokenizer vocab size {len(tokenizer)} must equal model vocab size {params['text_vocab_size']}"
    params["padding_id"] = tokenizer.encode(tokenizer.pad_token)[0]
    # Set up TPUs and Estimator
    if args.tpu == "colab":
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        ) if params["use_tpu"] else None
    else:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            args.tpu) if params["use_tpu"] else None

    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=params["model_path"],
        save_checkpoints_steps=None,  # Disable the default saver
        save_checkpoints_secs=None,  # Disable the default saver
        log_step_count_steps=params["iterations"],
        save_summary_steps=params["iterations"],
        tpu_config=tpu_config.TPUConfig(
            num_shards=mesh_shape.size,
            iterations_per_loop=params["iterations"],
            num_cores_per_replica=1,
            experimental_host_call_every_n_steps=100,
            per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=params["use_tpu"],
        model_fn=dalle_model_fn,
        config=config,
        train_batch_size=params["train_batch_size"],
        eval_batch_size=params["eval_batch_size"],
        predict_batch_size=params["predict_batch_size"],
        params=params)

    has_predict_or_eval_steps = params["predict_steps"] > 0 or params[
        "eval_steps"] > 0
    if has_predict_or_eval_steps:
        # Eval and train - stop and predict and/or eval every checkpoint
        while current_step < params["train_steps"]:
            next_checkpoint = min(current_step + args.steps_per_checkpoint,
                                  params["train_steps"])
            estimator.train(input_fn=partial(dalle_input_fn, eval=False),
                            max_steps=next_checkpoint)
            current_step = next_checkpoint
            if params["predict_steps"] > 0:
                raise NotImplementedError
            if params["eval_steps"] > 0:
                raise NotImplementedError
        return
    else:
        # Else, just train
        while current_step < params["train_steps"]:
            # Else, don't stop and restart
            estimator.train(input_fn=partial(dalle_input_fn, eval=False),
                            max_steps=params["train_steps"])
Code Example #10
def main(args):
    # Setup logging
    logger = setup_logging(args)

    # Read params of model
    params = fetch_model_params(args.model)

    # Fetch appropriate input functions
    input_fn = generic_text
    pred_input_fn = pred_input
    handle_pred_output_fn = handle_pred_output

    if params["mlm_training"]:
        mlm_sample_text_fn = partial(mlm_sample_text, params)
        input_fn = partial(generic_text, sample_text_fn=mlm_sample_text_fn)

    # Fetch encoder per params
    encoder = fetch_encoder(params)

    pred_input_fn = partial(pred_input_fn,
                            path_to_prompt=args.prompt,
                            logger=logger,
                            enc=encoder)

    # Sample from Dataset if check dataset flag is on
    if args.check_dataset:
        check_dataset(input_fn)

    # Confirm deletion of checkpoint files if --new flag is set
    if args.new:
        if yes_or_no(
                f"Are you sure you want to remove '{params['model_path']}' to start afresh?"
        ):
            remove_gs_or_filepath(params["model_path"])
        else:
            exit()

    # Save config to logdir for experiment management
    save_config(params, params["model_path"])

    # Add to params: auto_layout, auto_layout_and_mesh_shape, use_tpu, num_cores
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    params["num_cores"] = mesh_shape.size
    params["auto_layout"] = args.auto_layout
    params["auto_layout_and_mesh_shape"] = args.auto_layout_and_mesh_shape
    params["use_tpu"] = True if not args.tpu is None else False
    params["gpu_ids"] = args.gpu_ids
    params["steps_per_checkpoint"] = args.steps_per_checkpoint
    # Expand attention types param
    params["attention_types"] = expand_attention_types_params(
        params["attention_types"])
    assert len(params["attention_types"]) == params[
        "n_layer"]  # Assert that the length of expanded list = num layers
    params["predict_batch_size"] = params.get("predict_batch_size",
                                              1)  # Default to 1
    params["predict"] = args.predict
    # Default model selection to GPT since it's the only option for now
    params['model'] = params.get("model", "GPT")

    # Sample quality of MoE models suffers when using the faster sampling method, so default to slow_sampling if
    # moe layers are present
    params["slow_sampling"] = params["moe_layers"] is not None

    logger.info(f"params = {params}")

    # Get eval tasks from params
    eval_tasks = params.get("eval_tasks", [])
    has_predict_or_eval_steps_or_eval_tasks = params[
        "predict_steps"] > 0 or params["eval_steps"] > 0 or len(eval_tasks) > 0

    for t in eval_tasks:
        assert t in task_descriptors, f"Eval task '{t}' is not known"
        task_descriptors[t]["init_fn"](params)

    # Set up TPUs and Estimator
    if args.tpu == "colab":
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        ) if params["use_tpu"] else None
    else:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            args.tpu) if params["use_tpu"] else None

    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=params["model_path"],
        save_checkpoints_steps=None,  # Disable the default saver
        save_checkpoints_secs=None,  # Disable the default saver
        log_step_count_steps=params["iterations"],
        save_summary_steps=params["iterations"],
        tpu_config=tpu_config.TPUConfig(
            num_shards=mesh_shape.size,
            iterations_per_loop=params["iterations"],
            num_cores_per_replica=1,
            per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=params["use_tpu"],
        model_fn=model_fn,
        config=config,
        train_batch_size=params["train_batch_size"],
        eval_batch_size=params["train_batch_size"],
        predict_batch_size=params["predict_batch_size"],
        params=params)

    def _make_task_estimator(task):
        task_params = params.copy()
        task_params["eval_task"] = task
        return tpu_estimator.TPUEstimator(
            use_tpu=params["use_tpu"],
            model_fn=model_fn,
            config=config,
            train_batch_size=params["train_batch_size"],
            eval_batch_size=params["train_batch_size"],
            predict_batch_size=params["predict_batch_size"],
            params=task_params)

    # One estimator per auxiliary eval task, each carrying its own eval_task in params.
    eval_task_estimators = {
        task: _make_task_estimator(task)
        for task in eval_tasks
    }

    current_step = int(
        estimator_lib._load_global_step_from_checkpoint_dir(
            params["model_path"]))
    logger.info(f"Current step {current_step}")

    if args.predict:
        # Predict
        predictions = estimator.predict(input_fn=pred_input_fn)
        logger.info("Predictions generated")
        enc = fetch_encoder(params)
        handle_pred_output_fn(predictions,
                              logger,
                              enc,
                              params,
                              out_name=f"predictions_{current_step}")
        return

    elif has_predict_or_eval_steps_or_eval_tasks:
        # Eval and train - stop and predict and/or eval every checkpoint
        while current_step < params["train_steps"]:
            next_checkpoint = min(current_step + args.steps_per_checkpoint,
                                  params["train_steps"])

            estimator.train(input_fn=partial(input_fn, eval=False),
                            max_steps=next_checkpoint)
            current_step = next_checkpoint

            if params["predict_steps"] > 0:
                logger.info("Running prediction...")
                predictions = estimator.predict(input_fn=pred_input_fn)
                enc = fetch_encoder(params)
                handle_pred_output_fn(predictions,
                                      logger,
                                      enc,
                                      params,
                                      out_name=f"predictions_{current_step}")

            if params["eval_steps"] > 0:
                logger.info("Running evaluation...")
                eval_results = estimator.evaluate(input_fn=partial(input_fn,
                                                                   eval=True),
                                                  steps=params["eval_steps"])
                logger.info(f"Eval results: {eval_results}")

            for task in eval_tasks:
                logger.info(f"Starting evaluation task '{task}'")
                task_info = task_descriptors[task]["get_task_info_fn"](params)
                task_estimator = eval_task_estimators[task]
                task_input_fn = task_descriptors[task]["input_fn"]
                eval_results = task_estimator.evaluate(
                    input_fn=task_input_fn,
                    steps=task_info["n_steps"],
                    name=task)
                logger.info(f"Eval task '{task}' results: {eval_results}")
        return
    else:
        # Else, just train
        while current_step < params["train_steps"]:
            # Else, don't stop and restart
            estimator.train(input_fn=partial(input_fn, eval=False),
                            max_steps=params["train_steps"])
Code Example #11
File: tpu.py Project: shawwn/lm
    def execute(self, job: TPUJobSpec):
        "execut the give job spec"
        cluster = self.resolve()

        run_config = tpu_config.RunConfig(
            cluster=cluster,
            model_dir=job.model_path,
            save_checkpoints_steps=None,  # Disable the default saver
            save_checkpoints_secs=None,  # Disable the default saver
            log_step_count_steps=job.params["steps_per_iteration"],
            save_summary_steps=job.params["steps_per_checkpoint"],
            tpu_config=tpu_config.TPUConfig(
                num_shards=job.function.mesh_shape.size,
                iterations_per_loop=job.params["steps_per_iteration"],
                num_cores_per_replica=1,
                per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
            ),
        )

        estimator = tpu_estimator.TPUEstimator(
            use_tpu=job.use_tpu,
            model_fn=job.function,
            config=run_config,
            train_batch_size=job.infeed.batch_size,  # these change with the configuration
            eval_batch_size=job.infeed.batch_size,
            predict_batch_size=job.infeed.batch_size,
            params=job.params,
        )

        assert job.train or job.eval

        if job.train:
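            # Resume from the last checkpoint in model_dir, if one exists.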
            if tf.io.gfile.exists(job.model_path):
                logging.info("restoring checkpoint steps from %s",
                             job.model_path)
                current_step = int(
                    estimator_lib._load_global_step_from_checkpoint_dir(
                        job.model_path))
                logging.info("current step is now at %d", current_step)
            else:
                current_step = 0

            while current_step < job.max_steps:
                estimator.train(input_fn=job.infeed.function,
                                max_steps=job.max_steps)
                current_step = int(
                    estimator_lib._load_global_step_from_checkpoint_dir(
                        job.model_path))
                logging.info("step %s", current_step)
            logging.info("completed device execution after %s steps",
                         current_step)

            if job.export:
                estimator.export_saved_model(job.export, job.signature)

            return {"current_step": current_step}

        if job.eval:
            # If eval is on - stop and eval every ckpt
            logging.info("starting to evaluate.")
            eval_results = estimator.evaluate(input_fn=job.infeed.function,
                                              steps=job.max_steps)
            logging.info("completed eval. results: %s", eval_results)
            return eval_results