def main(_):
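    # Cache TF Hub modules under model_dir and log download progress for easier debugging.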
    os.environ["TFHUB_CACHE_DIR"] = os.path.join(FLAGS.model_dir,
                                                 "tfhub_modules")
    os.environ["TFHUB_DOWNLOAD_PROGRESS"] = "True"
    input_fn = partial(input_, iterations=FLAGS.iterations)
    cluster = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    run_config = tpu_config.RunConfig(model_dir=FLAGS.model_dir,
                                      cluster=cluster,
                                      tpu_config=tpu_config.TPUConfig(
                                          FLAGS.iterations))

    classifier = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                            use_tpu=FLAGS.use_tpu,
                                            train_batch_size=FLAGS.batch_size,
                                            eval_batch_size=FLAGS.batch_size,
                                            config=run_config,
                                            params={
                                                "use_tpu": FLAGS.use_tpu,
                                                "data_dir": FLAGS.data_dir,
                                                "dataset": FLAGS.dataset
                                            })

    classifier.train(input_fn=lambda params: input_fn(
        mode=tf.estimator.ModeKeys.TRAIN, **params),
                     max_steps=2000)
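Several of these examples pass a model_fn defined elsewhere in their projects. For reference, here is a minimal TPUEstimator-compatible sketch; the one-layer dense network, the SGD optimizer, and the way params is read are illustrative assumptions, not any project's actual model:

import tensorflow as tf


def model_fn(features, labels, mode, params):
    # Placeholder network: a single dense layer trained with SGD, wrapped in
    # CrossShardOptimizer so gradients are aggregated across TPU shards.
    logits = tf.compat.v1.layers.dense(features, 10)
    loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(labels=labels,
                                                            logits=logits)
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1e-3)
    if params.get("use_tpu"):
        optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)
    train_op = optimizer.minimize(
        loss, global_step=tf.compat.v1.train.get_global_step())
    return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                       loss=loss,
                                                       train_op=train_op)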
Example #2
def main(_):
  config = params_dict.ParamsDict(mask_rcnn_config.MASK_RCNN_CFG,
                                  mask_rcnn_config.MASK_RCNN_RESTRICTIONS)
  config = params_dict.override_params_dict(
      config, FLAGS.config, is_strict=True)
  config.is_training_bn = False
  config.train_batch_size = FLAGS.batch_size
  config.eval_batch_size = FLAGS.batch_size

  config.validate()
  config.lock()

  model_params = dict(
      list(config.as_dict().items()),
      use_tpu=FLAGS.use_tpu,
      mode=tf.estimator.ModeKeys.PREDICT,
      transpose_input=False)

  print(' - Setting up TPUEstimator...')
  estimator = tf.estimator.tpu.TPUEstimator(
      model_fn=serving.serving_model_fn_builder(
          FLAGS.output_source_id, FLAGS.output_image_info,
          FLAGS.output_box_features, FLAGS.output_normalized_coordinates,
          FLAGS.cast_num_detections_to_float),
      model_dir=FLAGS.model_dir,
      config=tpu_config.RunConfig(
          tpu_config=tpu_config.TPUConfig(
              iterations_per_loop=FLAGS.iterations_per_loop),
          master='local',
          evaluation_master='local'),
      params=model_params,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=FLAGS.batch_size,
      predict_batch_size=FLAGS.batch_size,
      export_to_tpu=FLAGS.use_tpu,
      export_to_cpu=True)

  print(' - Exporting the model...')
  input_type = FLAGS.input_type
  export_path = estimator.export_saved_model(
      export_dir_base=FLAGS.export_dir,
      serving_input_receiver_fn=functools.partial(
          serving.serving_input_fn,
          batch_size=FLAGS.batch_size,
          desired_image_size=config.image_size,
          padding_stride=(2**config.max_level),
          input_type=input_type,
          input_name=FLAGS.input_name),
      checkpoint_path=FLAGS.checkpoint_path)

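  # Warmup requests exercise the image-decoding path, so they only apply to image_bytes inputs.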
  if FLAGS.add_warmup_requests and input_type == 'image_bytes':
    inference_warmup.write_warmup_requests(
        export_path,
        FLAGS.model_name,
        config.image_size,
        batch_sizes=[FLAGS.batch_size],
        image_format='JPEG',
        input_signature=FLAGS.input_name)
  print(' - Done! path: %s' % export_path)
Example #3
 def test_evaluation_master_defaults_to_master_in_tf_config(self):
   tf_config = {
       'session_master': '_master_123',
   }
   with _set_tf_config_env_variable(tf_config):
     run_config = tpu_config_lib.RunConfig()
     self.assertEqual('_master_123', run_config.master)
     self.assertEqual('_master_123', run_config.evaluation_master)
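The RunConfig tests in this listing use a _set_tf_config_env_variable helper that is not shown. A minimal sketch, assuming the helper simply JSON-encodes the dict into the TF_CONFIG environment variable for the duration of the with block and restores the previous value afterwards:

import contextlib
import json
import os


@contextlib.contextmanager
def _set_tf_config_env_variable(tf_config):
  """Temporarily exposes `tf_config` via the TF_CONFIG env var (assumed helper)."""
  previous = os.environ.get('TF_CONFIG')
  os.environ['TF_CONFIG'] = json.dumps(tf_config)
  try:
    yield
  finally:
    if previous is None:
      os.environ.pop('TF_CONFIG', None)
    else:
      os.environ['TF_CONFIG'] = previous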
Example #4
 def test_user_overwrites_master_in_tf_config(self):
   tf_config = {
       'session_master': '_master_123',
       'eval_session_master': '_eval_master_123'
   }
   with _set_tf_config_env_variable(tf_config):
     run_config = tpu_config_lib.RunConfig(master='_new_master_123')
     self.assertEqual('_new_master_123', run_config.master)
     self.assertEqual('_eval_master_123', run_config.evaluation_master)
Example #5
 def test_respect_evaluation_master_in_tf_config(self):
   tf_config = {
       'cluster': {
           run_config_lib.TaskType.CHIEF: ['host0:0'],
       },
       'task': {
           'type': run_config_lib.TaskType.EVALUATOR,
           'index': 0
       },
   }
   with _set_tf_config_env_variable(tf_config):
     run_config = tpu_config_lib.RunConfig(master='_something')
     self.assertEqual('', run_config.evaluation_master)
Example #6
 def test_no_session_config_set_with_cluster_spec(self):
   tf_config = {
       'cluster': {
           run_config_lib.TaskType.CHIEF: ['host3:3'],
           run_config_lib.TaskType.WORKER: ['host3:4']
       },
       'task': {
           'type': run_config_lib.TaskType.CHIEF,
           'index': 0
       }
   }
   with _set_tf_config_env_variable(tf_config):
     run_config = tpu_config_lib.RunConfig()
     self.assertIsNone(run_config.session_config)
Example #7
 def test_no_session_config_overwrite_with_cluster_spec(self):
   tf_config = {
       'cluster': {
           run_config_lib.TaskType.CHIEF: ['host3:3'],
           run_config_lib.TaskType.WORKER: ['host3:4']
       },
       'task': {
           'type': run_config_lib.TaskType.CHIEF,
           'index': 0
       }
   }
   with _set_tf_config_env_variable(tf_config):
     session_config = config_pb2.ConfigProto(allow_soft_placement=True)
     run_config = tpu_config_lib.RunConfig(session_config=session_config)
     self.assertEqual(session_config, run_config.session_config)
Example #8
  def testAsyncCheckpointHookEnabled(self):
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)

    checkpoint_interval = 5
    config = tpu_config.RunConfig(
        master=resolver.master(),
        model_dir=os.path.join(FLAGS.model_dir, 'runconfig'),
        save_checkpoints_steps=1000,
        keep_checkpoint_max=11,  # off by one
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=checkpoint_interval,))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        config=config,
        train_batch_size=32,
        eval_batch_size=32,
        predict_batch_size=1,
        params={},
    )

    i = 10
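    # Autospec the saver listener so its before_save/after_save calls can be asserted on.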
    mock_listener = test.mock.create_autospec(
        basic_session_run_hooks.CheckpointSaverListener)
    estimator.train(
        input_fn=input_fn,
        max_steps=i * 10,
        hooks=[
            async_checkpoint.AsyncCheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=checkpoint_interval,
                listeners=[mock_listener])
        ])

    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access

    # TODO(power) -- identify a better way to count the number of checkpoints.
    checkpoints = file_io.get_matching_files(
        FLAGS.model_dir + '/model.ckpt*.meta')
    checkpoint_count = len(checkpoints)
    logging.info('Found %d checkpoints: %s', checkpoint_count, checkpoints)
    self.assertLessEqual(checkpoint_count, 10)
    self.assertEqual(current_step, i * 10)
    mock_listener.before_save.assert_called()
    mock_listener.after_save.assert_called()
Example #9
def main(_):
    os.environ["TFHUB_CACHE_DIR"] = os.path.join(FLAGS.model_dir,
                                                 "tfhub_modules")
    os.environ["TFHUB_DOWNLOAD_PROGRESS"] = "True"
    input_fn = partial(input_, iterations=FLAGS.iterations)
    cluster = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    run_config = tpu_config.RunConfig(model_dir=FLAGS.model_dir,
                                      cluster=cluster,
                                      tpu_config=tpu_config.TPUConfig(
                                          FLAGS.iterations))

    classifier = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                            use_tpu=FLAGS.use_tpu,
                                            train_batch_size=FLAGS.batch_size,
                                            eval_batch_size=FLAGS.batch_size,
                                            config=run_config,
                                            params={
                                                "use_tpu":
                                                FLAGS.use_tpu,
                                                "data_dir":
                                                FLAGS.data_dir,
                                                "dataset":
                                                FLAGS.dataset,
                                                "use_compat":
                                                FLAGS.use_compat,
                                                "learning_rate":
                                                FLAGS.learning_rate
                                            })
    try:
        classifier.train(input_fn=lambda params: input_fn(
            mode=tf.estimator.ModeKeys.TRAIN, **params),
                         max_steps=FLAGS.max_steps)
    except Exception:  # Deliberately swallow training errors so inference below can still run.
        pass
    if FLAGS.infer:

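        # Build a one-off input function for the single image path passed via --infer.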
        def prepare_input_fn(path):
            img = tf.image.decode_image(tf.io.read_file(path))
            return resize_and_scale(img, None)

        predictions = classifier.predict(
            input_fn=lambda params: prepare_input_fn(FLAGS.infer))
        print(predictions)
Example #10
def run_toy_model_tpu():
    """Run a toy model on TPU."""
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    iterations_per_loop = FLAGS.iterations
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
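    # One TPU shard per mesh core; BROADCAST replicates host 0's input pipeline to every replica.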
    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=None,  # Disable the default saver
        save_checkpoints_secs=None,  # Disable the default saver
        log_step_count_steps=iterations_per_loop,
        save_summary_steps=iterations_per_loop,
        tpu_config=tpu_config.TPUConfig(
            num_shards=mesh_shape.size,
            iterations_per_loop=iterations_per_loop,
            num_cores_per_replica=1,
            per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))
    classifier = tpu_estimator.TPUEstimator(use_tpu=True,
                                            model_fn=model_fn,
                                            config=config,
                                            train_batch_size=FLAGS.batch_size,
                                            eval_batch_size=FLAGS.batch_size)
    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    logging.info('Current step %d', current_step)
    if FLAGS.steps_per_checkpoint == 0:
        classifier.train(input_fn=ToyModelInput(), max_steps=FLAGS.train_steps)
        return
    while current_step < FLAGS.train_steps:
        next_checkpoint = min(current_step + FLAGS.steps_per_checkpoint,
                              FLAGS.train_steps)
        classifier.train(input_fn=ToyModelInput(), max_steps=next_checkpoint)
        current_step = next_checkpoint
        logging.info('Starting to evaluate.')
        eval_results = classifier.evaluate(
            input_fn=ToyModelInput(), steps=156
        )  # since we have 10000 examples and batch_size = 64 per host
        logging.info('Eval results: %s', eval_results)
Example #11
 def test_default_values(self):
   run_config = tpu_config_lib.RunConfig()
   self.assertEqual('', run_config.master)
   self.assertEqual('', run_config.evaluation_master)
Example #12
 def test_fail_with_iterations_per_loop(self):
   with self.assertRaisesRegexp(ValueError, 'must be positive'):
     tpu_config_lib.RunConfig(
         tpu_config=tpu_config_lib.TPUConfig(iterations_per_loop=0))
Example #13
 def test_fail_with_invalid_num_shards(self):
   with self.assertRaisesRegexp(ValueError, 'must be positive'):
     tpu_config_lib.RunConfig(
         tpu_config=tpu_config_lib.TPUConfig(num_shards=0))
Example #14
def main(args):
    # Setup logging
    logger = setup_logging(args)

    # Read params of model
    params = fetch_model_params(args.model)

    # Fetch appropriate input functions
    input_fn = generic_text
    pred_input_fn = pred_input
    handle_pred_output_fn = handle_pred_output

    if params["mlm_training"]:
        mlm_sample_text_fn = partial(mlm_sample_text, params)
        input_fn = partial(generic_text, sample_text_fn=mlm_sample_text_fn)

    # Fetch encoder per params
    encoder = fetch_encoder(params)

    pred_input_fn = partial(pred_input_fn,
                            path_to_prompt=args.prompt,
                            logger=logger,
                            enc=encoder)

    # Sample from Dataset if check dataset flag is on
    if args.check_dataset:
        check_dataset(input_fn)

    # Confirm deletion of checkpoint files if --new flag is set
    if args.new:
        if yes_or_no(
                f"Are you sure you want to remove '{params['model_path']}' to start afresh?"
        ):
            remove_gs_or_filepath(params["model_path"])
        else:
            exit()

    # Save config to logdir for experiment management
    save_config(params, params["model_path"])

    # Add to params: auto_layout, auto_layout_and_mesh_shape, use_tpu, num_cores
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    params["num_cores"] = mesh_shape.size
    params["auto_layout"] = args.auto_layout
    params["auto_layout_and_mesh_shape"] = args.auto_layout_and_mesh_shape
    params["use_tpu"] = True if not args.tpu is None else False
    params["gpu_ids"] = args.gpu_ids
    params["steps_per_checkpoint"] = args.steps_per_checkpoint
    # Expand attention types param
    params["attention_types"] = expand_attention_types_params(
        params["attention_types"])
    assert len(params["attention_types"]) == params[
        "n_layer"]  # Assert that the length of expanded list = num layers
    params["predict_batch_size"] = params.get("predict_batch_size",
                                              1)  # Default to 1
    params["predict"] = args.predict
    # Default model selection to GPT since it's the only option for now.
    params["model"] = params.get("model", "GPT")

    # Sample quality of MoE models suffers with the faster sampling method,
    # so default to slow_sampling when MoE layers are present.
    params["slow_sampling"] = params["moe_layers"] is not None

    logger.info(f"params = {params}")

    # Get eval tasks from params
    eval_tasks = params.get("eval_tasks", [])
    has_predict_or_eval_steps_or_eval_tasks = params[
        "predict_steps"] > 0 or params["eval_steps"] > 0 or len(eval_tasks) > 0

    for t in eval_tasks:
        assert t in task_descriptors, f"Eval task '{t}' is not known"
        task_descriptors[t]["init_fn"](params)

    # Set up TPUs and Estimator
    if args.tpu == "colab":
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        ) if params["use_tpu"] else None
    else:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            args.tpu) if params["use_tpu"] else None

    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=params["model_path"],
        save_checkpoints_steps=None,  # Disable the default saver
        save_checkpoints_secs=None,  # Disable the default saver
        log_step_count_steps=params["iterations"],
        save_summary_steps=params["iterations"],
        tpu_config=tpu_config.TPUConfig(
            num_shards=mesh_shape.size,
            iterations_per_loop=params["iterations"],
            num_cores_per_replica=1,
            per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=params["use_tpu"],
        model_fn=model_fn,
        config=config,
        train_batch_size=params["train_batch_size"],
        eval_batch_size=params["train_batch_size"],
        predict_batch_size=params["predict_batch_size"],
        params=params)

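    # Each eval task gets its own estimator so task-specific params can be injected.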
    def _make_task_estimator(task):
        task_params = params.copy()
        task_params["eval_task"] = task
        return tpu_estimator.TPUEstimator(
            use_tpu=params["use_tpu"],
            model_fn=model_fn,
            config=config,
            train_batch_size=params["train_batch_size"],
            eval_batch_size=params["train_batch_size"],
            predict_batch_size=params["predict_batch_size"],
            params=task_params)

    eval_task_estimators = {
        task: _make_task_estimator(task)
        for task in eval_tasks
    }

    current_step = int(
        estimator_lib._load_global_step_from_checkpoint_dir(
            params["model_path"]))
    logger.info(f"Current step {current_step}")

    if args.predict:
        # Predict
        predictions = estimator.predict(input_fn=pred_input_fn)
        logger.info("Predictions generated")
        enc = fetch_encoder(params)
        handle_pred_output_fn(predictions,
                              logger,
                              enc,
                              params,
                              out_name=f"predictions_{current_step}")
        return

    elif has_predict_or_eval_steps_or_eval_tasks:
        # Eval and train - stop and predict and/or eval every checkpoint
        while current_step < params["train_steps"]:
            next_checkpoint = min(current_step + args.steps_per_checkpoint,
                                  params["train_steps"])

            estimator.train(input_fn=partial(input_fn, eval=False),
                            max_steps=next_checkpoint)
            current_step = next_checkpoint

            if params["predict_steps"] > 0:
                logger.info("Running prediction...")
                predictions = estimator.predict(input_fn=pred_input_fn)
                enc = fetch_encoder(params)
                handle_pred_output_fn(predictions,
                                      logger,
                                      enc,
                                      params,
                                      out_name=f"predictions_{current_step}")

            if params["eval_steps"] > 0:
                logger.info("Running evaluation...")
                eval_results = estimator.evaluate(input_fn=partial(input_fn,
                                                                   eval=True),
                                                  steps=params["eval_steps"])
                logger.info(f"Eval results: {eval_results}")

            for task in eval_tasks:
                logger.info(f"Starting evaluation task '{task}'")
                task_info = task_descriptors[task]["get_task_info_fn"](params)
                task_estimator = eval_task_estimators[task]
                task_input_fn = task_descriptors[task]["input_fn"]
                eval_results = task_estimator.evaluate(
                    input_fn=task_input_fn,
                    steps=task_info["n_steps"],
                    name=task)
                logger.info(f"Eval task '{task}' results: {eval_results}")
        return
    else:
        # Else, just train
        while current_step < params["train_steps"]:
            # Else, don't stop and restart
            estimator.train(input_fn=partial(input_fn, eval=False),
                            max_steps=params["train_steps"])
Example #15
def main():
    # parse args and params
    args = parse_args()
    logging = setup_logging(args)
    params = fetch_model_params(args.model)
    params["vae_params"] = fetch_model_params(params["vae_model"])
    assert params["model_type"].lower(
    ) == "dalle", f'model_type {params["model_type"]} not recognized'

    # Confirm deletion of checkpoint files if --new flag is set
    if args.new:
        maybe_remove_gs_or_filepath(params["model_path"])

    # get current step
    current_step = int(
        estimator_lib._load_global_step_from_checkpoint_dir(
            params["model_path"]))
    logging.info(f"Current step: {current_step}")

    # Add to params:
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    params["num_cores"] = mesh_shape.size
    params["use_tpu"] = True if not args.tpu is None else False
    params["gpu_ids"] = args.gpu_ids
    tokenizer = get_tokenizer(params["tokenizer"])
    assert len(tokenizer) == params["text_vocab_size"], (
        f"tokenizer vocab size {len(tokenizer)} must equal "
        f"model vocab size {params['text_vocab_size']}")
    params["padding_id"] = tokenizer.encode(tokenizer.pad_token)[0]
    # Set up TPUs and Estimator
    if args.tpu == "colab":
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        ) if params["use_tpu"] else None
    else:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            args.tpu) if params["use_tpu"] else None

    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=params["model_path"],
        save_checkpoints_steps=None,  # Disable the default saver
        save_checkpoints_secs=None,  # Disable the default saver
        log_step_count_steps=params["iterations"],
        save_summary_steps=params["iterations"],
        tpu_config=tpu_config.TPUConfig(
            num_shards=mesh_shape.size,
            iterations_per_loop=params["iterations"],
            num_cores_per_replica=1,
            experimental_host_call_every_n_steps=100,
            per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=params["use_tpu"],
        model_fn=dalle_model_fn,
        config=config,
        train_batch_size=params["train_batch_size"],
        eval_batch_size=params["eval_batch_size"],
        predict_batch_size=params["predict_batch_size"],
        params=params)

    has_predict_or_eval_steps = params["predict_steps"] > 0 or params[
        "eval_steps"] > 0
    if has_predict_or_eval_steps:
        # Eval and train - stop and predict and/or eval every checkpoint
        while current_step < params["train_steps"]:
            next_checkpoint = min(current_step + args.steps_per_checkpoint,
                                  params["train_steps"])
            estimator.train(input_fn=partial(dalle_input_fn, eval=False),
                            max_steps=next_checkpoint)
            current_step = next_checkpoint
            if params["predict_steps"] > 0:
                raise NotImplementedError
            if params["eval_steps"] > 0:
                raise NotImplementedError
        return
    else:
        # Else, just train
        while current_step < params["train_steps"]:
            # Else, don't stop and restart
            estimator.train(input_fn=partial(dalle_input_fn, eval=False),
                            max_steps=params["train_steps"])
Example #16
 def test_no_session_config_overwrite_in_local_case(self):
   session_config = config_pb2.ConfigProto(allow_soft_placement=True)
   run_config = tpu_config_lib.RunConfig(session_config=session_config)
   self.assertEqual(session_config, run_config.session_config)
Example #17
 def test_no_session_config_set_in_local_case(self):
   run_config = tpu_config_lib.RunConfig()
   self.assertIsNone(run_config.session_config)
Example #18
 def test_with_tf_config(self):
   tf_config = {'service': {'tpu_worker_job_name': '_my_new_name',}}
   with _set_tf_config_env_variable(tf_config):
     config = tpu_config_lib.RunConfig()
     self.assertEqual('_my_new_name', config.tpu_config.tpu_job_name)
Example #19
File: tpu.py, Project: shawwn/lm
    def execute(self, job: TPUJobSpec):
        "execut the give job spec"
        cluster = self.resolve()

        run_config = tpu_config.RunConfig(
            cluster=cluster,
            model_dir=job.model_path,
            save_checkpoints_steps=None,  # Disable the default saver
            save_checkpoints_secs=None,  # Disable the default saver
            log_step_count_steps=job.params["steps_per_iteration"],
            save_summary_steps=job.params["steps_per_checkpoint"],
            tpu_config=tpu_config.TPUConfig(
                num_shards=job.function.mesh_shape.size,
                iterations_per_loop=job.params["steps_per_iteration"],
                num_cores_per_replica=1,
                per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
            ),
        )

        estimator = tpu_estimator.TPUEstimator(
            use_tpu=job.use_tpu,
            model_fn=job.function,
            config=run_config,
            train_batch_size=job.infeed.batch_size,  # these change with the configuration
            eval_batch_size=job.infeed.batch_size,
            predict_batch_size=job.infeed.batch_size,
            params=job.params,
        )

        assert job.train or job.eval

        if job.train:
            if tf.io.gfile.exists(job.model_path):
                logging.info("restoring checkpoint steps from %s",
                             job.model_path)
                current_step = int(
                    estimator_lib._load_global_step_from_checkpoint_dir(
                        job.model_path))
                logging.info("current step is now at %d", current_step)
            else:
                current_step = 0

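            # Re-enter train() until the checkpointed global step reaches
            # max_steps; progress is read back from the checkpoint each pass,
            # which tolerates preemptions and restarts.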
            while current_step < job.max_steps:
                estimator.train(input_fn=job.infeed.function,
                                max_steps=job.max_steps)
                current_step = int(
                    estimator_lib._load_global_step_from_checkpoint_dir(
                        job.model_path))
                logging.info("step %s", current_step)
            logging.info("completed device execution after %s steps",
                         current_step)

            if job.export:
                estimator.export_saved_model(job.export, job.signature)

            return {"current_step": current_step}

        if job.eval:
            # If eval is on - stop and eval every ckpt
            logging.info("starting to evaluate.")
            eval_results = estimator.evaluate(input_fn=job.infeed.function,
                                              steps=job.max_steps)
            logging.info("completed eval. results: %s", eval_results)
            return eval_results
Example #20
 def test_user_provided_master_and_evaluation_master(self):
   run_config = tpu_config_lib.RunConfig(
       master='_master_123', evaluation_master='_eval_master_123')
   self.assertEqual('_master_123', run_config.master)
   self.assertEqual('_eval_master_123', run_config.evaluation_master)
Example #21
 def test_evaluation_master_defaults_to_master(self):
   run_config = tpu_config_lib.RunConfig(master='_master_123')
   self.assertEqual('_master_123', run_config.master)
   self.assertEqual('_master_123', run_config.evaluation_master)
Example #22
def main():
    # parse args and params
    args = parse_args()
    logging = setup_logging(args)
    params = fetch_model_params(args.model)
    assert params["model_type"].lower(
    ) == "vae", f'model_type {params["model_type"]} not recognized'

    # Confirm deletion of checkpoint files if --new flag is set
    if args.new:
        maybe_remove_gs_or_filepath(params["model_path"])

    # get current step
    current_step = int(
        estimator_lib._load_global_step_from_checkpoint_dir(
            params["model_path"]))
    logging.info(f"Current step: {current_step}")

    # Add to params:
    params["use_tpu"] = True if not args.tpu is None else False
    params["gpu_ids"] = args.gpu_ids

    # Set up TPUs and Estimator
    if args.tpu == "colab":
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        ) if params["use_tpu"] else None
    else:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            args.tpu) if params["use_tpu"] else None

    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=params["model_path"],
        save_checkpoints_steps=params["steps_per_checkpoint"],
        log_step_count_steps=params["iterations"],
        save_summary_steps=params["iterations"],
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=params["iterations"],
            num_cores_per_replica=1,
            experimental_host_call_every_n_steps=100,
            per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=params["use_tpu"],
        model_fn=vae_model_fn,
        config=config,
        train_batch_size=params["train_batch_size"],
        eval_batch_size=params["eval_batch_size"],
        predict_batch_size=params["predict_batch_size"],
        params=params)

    has_predict_or_eval_steps = params["predict_steps"] > 0 or params[
        "eval_steps"] > 0
    if has_predict_or_eval_steps:
        # Eval and train - stop and predict and/or eval every checkpoint
        while current_step < params["train_steps"]:
            next_checkpoint = min(
                current_step + params["steps_per_checkpoint"],
                params["train_steps"])
            estimator.train(input_fn=partial(vae_input_fn, eval=False),
                            max_steps=next_checkpoint)
            current_step = next_checkpoint
            logging.info(f"Current step: {current_step}")
            if params["predict_steps"] > 0:
                raise NotImplementedError
            if params["eval_steps"] > 0:
                logging.info(f"Starting eval")
                estimator.evaluate(input_fn=partial(vae_input_fn, eval=True),
                                   steps=params["eval_steps"])

        return
    else:
        # Else, just train
        while current_step < params["train_steps"]:
            # Else, don't stop and restart
            estimator.train(input_fn=partial(vae_input_fn, eval=False),
                            max_steps=params["train_steps"])
Example #23
 def test_default_name(self):
   config = tpu_config_lib.RunConfig()
   self.assertIsNone(config.tpu_config.tpu_job_name)