Example #1
    def after_run(self, run_context, run_values):
        global_step = run_values.results + 1
        if global_step >= self._last_step:
            # Check latest global step in the checkpoint to ensure that the targeted
            # last step is written on disk.

            step = estimator_lib._load_global_step_from_checkpoint_dir(
                self._model_dir)
            if step >= self._last_step:
                run_context.request_stop()
            else:
                time.sleep(self._wait_after_file_check_secs)
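
The snippet above shows only the after_run method of a stop-at-final-step hook; the class around it is not included. Below is a minimal, self-contained sketch of the same pattern. The class name, constructor arguments, and the estimator_lib import path are assumptions for illustration, not taken from the example.

import time

import tensorflow.compat.v1 as tf
from tensorflow.python.estimator import estimator as estimator_lib  # assumed alias


class StopAtCheckpointedStepHook(tf.estimator.SessionRunHook):
    """Hypothetical hook: stop only once the last step is checkpointed."""

    def __init__(self, model_dir, last_step, wait_after_file_check_secs=1):
        self._model_dir = model_dir
        self._last_step = last_step
        self._wait_after_file_check_secs = wait_after_file_check_secs

    def begin(self):
        self._global_step_tensor = tf.train.get_or_create_global_step()

    def before_run(self, run_context):
        # Fetch the global step on every run so after_run can inspect it.
        return tf.estimator.SessionRunArgs(self._global_step_tensor)

    def after_run(self, run_context, run_values):
        global_step = run_values.results + 1
        if global_step >= self._last_step:
            # Only stop once a checkpoint at (or past) the last step exists.
            step = estimator_lib._load_global_step_from_checkpoint_dir(
                self._model_dir)
            if step >= self._last_step:
                run_context.request_stop()
            else:
                time.sleep(self._wait_after_file_check_secs)
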
Example #2
  def testAsyncCheckpointHookEnabled(self):
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)

    checkpoint_interval = 5
    config = tpu_config.RunConfig(
        master=resolver.master(),
        model_dir=os.path.join(FLAGS.model_dir, 'runconfig'),
        save_checkpoints_steps=1000,
        keep_checkpoint_max=11,  # off by one
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=checkpoint_interval,))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        config=config,
        train_batch_size=32,
        eval_batch_size=32,
        predict_batch_size=1,
        params={},
    )

    i = 10
    mock_listener = test.mock.create_autospec(
        basic_session_run_hooks.CheckpointSaverListener)
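    # Train with AsyncCheckpointSaverHook attached: it saves every
    # checkpoint_interval steps and notifies the mocked listener around each
    # save.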
    estimator.train(
        input_fn=input_fn,
        max_steps=i * 10,
        hooks=[
            async_checkpoint.AsyncCheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=checkpoint_interval,
                listeners=[mock_listener])
        ])

    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access

    # TODO(power) -- identify a better way to count the number of checkpoints.
    checkpoints = file_io.get_matching_files(
        FLAGS.model_dir + '/model.ckpt*.meta')
    checkpoint_count = len(checkpoints)
    logging.info('Found %d checkpoints: %s', checkpoint_count, checkpoints)
    self.assertLessEqual(checkpoint_count, 10)
    self.assertEqual(current_step, i * 10)
    mock_listener.before_save.assert_called()
    mock_listener.after_save.assert_called()
Example #3
def run_toy_model_tpu():
    """Run a toy model on TPU."""
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    iterations_per_loop = FLAGS.iterations
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=None,  # Disable the default saver
        save_checkpoints_secs=None,  # Disable the default saver
        log_step_count_steps=iterations_per_loop,
        save_summary_steps=iterations_per_loop,
        tpu_config=tpu_config.TPUConfig(
            num_shards=mesh_shape.size,
            iterations_per_loop=iterations_per_loop,
            num_cores_per_replica=1,
            per_host_input_for_training=tpu_config.InputPipelineConfig.
            BROADCAST))
    classifier = tpu_estimator.TPUEstimator(use_tpu=True,
                                            model_fn=model_fn,
                                            config=config,
                                            train_batch_size=FLAGS.batch_size,
                                            eval_batch_size=FLAGS.batch_size)
    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    logging.info('Current step %d', current_step)
    if FLAGS.steps_per_checkpoint == 0:
        classifier.train(input_fn=ToyModelInput(), max_steps=FLAGS.train_steps)
        return
    while current_step < FLAGS.train_steps:
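        # Train only up to the next checkpoint boundary, then run one round of
        # evaluation before resuming, until FLAGS.train_steps is reached.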
        next_checkpoint = min(current_step + FLAGS.steps_per_checkpoint,
                              FLAGS.train_steps)
        classifier.train(input_fn=ToyModelInput(), max_steps=next_checkpoint)
        current_step = next_checkpoint
        logging.info('Starting to evaluate.')
        eval_results = classifier.evaluate(
            input_fn=ToyModelInput(), steps=156
        )  # since we have 10000 examples and batch_size = 64 per host
        logging.info('Eval results: %s', eval_results)
Example #4
def main():
    # parse args and params
    args = parse_args()
    logging = setup_logging(args)
    params = fetch_model_params(args.model)
    assert params["model_type"].lower(
    ) == "vae", f'model_type {params["model_type"]} not recognized'

    # Confirm deletion of checkpoint files if --new flag is set
    if args.new:
        maybe_remove_gs_or_filepath(params["model_path"])

    # get current step
    current_step = int(
        estimator_lib._load_global_step_from_checkpoint_dir(
            params["model_path"]))
    logging.info(f"Current step: {current_step}")

    # Add to params:
    params["use_tpu"] = True if not args.tpu is None else False
    params["gpu_ids"] = args.gpu_ids

    # Set up TPUs and Estimator
    if args.tpu == "colab":
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        ) if params["use_tpu"] else None
    else:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            args.tpu) if params["use_tpu"] else None

    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=params["model_path"],
        save_checkpoints_steps=params["steps_per_checkpoint"],
        log_step_count_steps=params["iterations"],
        save_summary_steps=params["iterations"],
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=params["iterations"],
            num_cores_per_replica=1,
            experimental_host_call_every_n_steps=100,
            per_host_input_for_training=tpu_config.InputPipelineConfig.
            BROADCAST))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=params["use_tpu"],
        model_fn=vae_model_fn,
        config=config,
        train_batch_size=params["train_batch_size"],
        eval_batch_size=params["eval_batch_size"],
        predict_batch_size=params["predict_batch_size"],
        params=params)

    has_predict_or_eval_steps = params["predict_steps"] > 0 or params[
        "eval_steps"] > 0
    if has_predict_or_eval_steps:
        # Eval and train - stop and predict and/or eval every checkpoint
        while current_step < params["train_steps"]:
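            # Train to the next checkpoint boundary, then handle the
            # predict/eval settings below before resuming training.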
            next_checkpoint = min(
                current_step + params["steps_per_checkpoint"],
                params["train_steps"])
            estimator.train(input_fn=partial(vae_input_fn, eval=False),
                            max_steps=next_checkpoint)
            current_step = next_checkpoint
            logging.info(f"Current step: {current_step}")
            if params["predict_steps"] > 0:
                raise NotImplementedError
            if params["eval_steps"] > 0:
                logging.info(f"Starting eval")
                estimator.evaluate(input_fn=partial(vae_input_fn, eval=True),
                                   steps=params["eval_steps"])

        return
    else:
        # Else, just train
        while current_step < params["train_steps"]:
            # Train straight through to train_steps without stopping for eval,
            # then re-read the global step so the loop can terminate.
            estimator.train(input_fn=partial(vae_input_fn, eval=False),
                            max_steps=params["train_steps"])
            current_step = int(
                estimator_lib._load_global_step_from_checkpoint_dir(
                    params["model_path"]))
Example #5
def main(argv):
    del argv  # Unused.

    tf.enable_resource_variables()
    tf.set_random_seed(FLAGS.seed)
    set_lr_schedule()
    set_custom_sparsity_map()
    folder_stub = os.path.join(FLAGS.training_method, str(FLAGS.end_sparsity),
                               str(FLAGS.maskupdate_begin_step),
                               str(FLAGS.maskupdate_end_step),
                               str(FLAGS.maskupdate_frequency),
                               str(FLAGS.drop_fraction),
                               str(FLAGS.label_smoothing),
                               str(FLAGS.weight_decay))

    output_dir = FLAGS.output_dir
    if FLAGS.use_folder_stub:
        output_dir = os.path.join(output_dir, folder_stub)

    export_dir = os.path.join(output_dir, 'export_dir')

    # we pass the updated eval and train string to the params dictionary.
    params = {}
    params['output_dir'] = output_dir
    params['training_method'] = FLAGS.training_method
    params['use_tpu'] = FLAGS.use_tpu

    dataset_func = functools.partial(
        imagenet_input.ImageNetInput,
        data_dir=FLAGS.data_directory,
        transpose_input=False,
        num_parallel_calls=FLAGS.num_parallel_calls,
        use_bfloat16=False)
    imagenet_train, imagenet_eval = [
        dataset_func(is_training=is_training) for is_training in [True, False]
    ]

    run_config = tpu_config.RunConfig(
        master=FLAGS.master,
        model_dir=output_dir,
        save_checkpoints_steps=FLAGS.steps_per_checkpoint,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        session_config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False),
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores,
            tpu_job_name=FLAGS.tpu_job_name))

    classifier = tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_model_fn_w_pruning,
        params=params,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    cpu_classifier = tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_model_fn_w_pruning,
        params=params,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        export_to_tpu=False,
        eval_batch_size=FLAGS.eval_batch_size)

    if FLAGS.num_eval_images % FLAGS.eval_batch_size != 0:
        raise ValueError(
            'eval_batch_size (%d) must evenly divide num_eval_images(%d)!' %
            (FLAGS.eval_batch_size, FLAGS.num_eval_images))

    eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size
    if FLAGS.mode == 'eval_once':
        ckpt_path = os.path.join(output_dir, FLAGS.eval_once_ckpt_prefix)
        dataset = imagenet_train if FLAGS.eval_on_train else imagenet_eval
        classifier.evaluate(input_fn=dataset.input_fn,
                            steps=eval_steps,
                            checkpoint_path=ckpt_path,
                            name='{0}'.format(FLAGS.eval_once_ckpt_prefix))
    elif FLAGS.mode == 'eval':
        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(output_dir):
            tf.logging.info('Starting to evaluate.')
            try:
                dataset = imagenet_train if FLAGS.eval_on_train else imagenet_eval
                classifier.evaluate(input_fn=dataset.input_fn,
                                    steps=eval_steps,
                                    checkpoint_path=ckpt,
                                    name='eval')
                # Terminate eval job when final checkpoint is reached
                global_step = int(os.path.basename(ckpt).split('-')[1])
                if global_step >= FLAGS.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d' %
                        global_step)
                    break

            except tf.errors.NotFoundError:
                tf.logging.info(
                    'Checkpoint no longer exists, skipping checkpoint.')

    else:
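        # Training modes: resume from the global step recorded in the latest
        # checkpoint under output_dir (0 if no checkpoint exists yet).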
        global_step = estimator._load_global_step_from_checkpoint_dir(
            output_dir)
        # Session run hooks to export model for prediction
        export_hook = ExportModelHook(cpu_classifier, export_dir)
        hooks = [export_hook]

        if FLAGS.mode == 'train':
            tf.logging.info('start training...')
            classifier.train(input_fn=imagenet_train.input_fn,
                             hooks=hooks,
                             max_steps=FLAGS.train_steps)
        else:
            assert FLAGS.mode == 'train_and_eval'
            tf.logging.info('start training and eval...')
            while global_step < FLAGS.train_steps:
                next_checkpoint = min(global_step + FLAGS.steps_per_eval,
                                      FLAGS.train_steps)
                classifier.train(input_fn=imagenet_train.input_fn,
                                 max_steps=next_checkpoint)
                global_step = next_checkpoint
                tf.logging.info('Completed training up to step: %d',
                                global_step)
                classifier.evaluate(input_fn=imagenet_eval.input_fn,
                                    steps=eval_steps)
Example #6
def main():
    # parse args and params
    args = parse_args()
    logging = setup_logging(args)
    params = fetch_model_params(args.model)
    params["vae_params"] = fetch_model_params(params["vae_model"])
    assert params["model_type"].lower(
    ) == "dalle", f'model_type {params["model_type"]} not recognized'

    # Confirm deletion of checkpoint files if --new flag is set
    if args.new:
        maybe_remove_gs_or_filepath(params["model_path"])

    # get current step
    current_step = int(
        estimator_lib._load_global_step_from_checkpoint_dir(
            params["model_path"]))
    logging.info(f"Current step: {current_step}")

    # Add to params:
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    params["num_cores"] = mesh_shape.size
    params["use_tpu"] = True if not args.tpu is None else False
    params["gpu_ids"] = args.gpu_ids
    tokenizer = get_tokenizer(params["tokenizer"])
    assert len(tokenizer) == params[
        "text_vocab_size"], f"tokenizer vocab size {len(tokenizer)} must equal model vocab size {params['text_vocab_size']}"
    params["padding_id"] = tokenizer.encode(tokenizer.pad_token)[0]
    # Set up TPUs and Estimator
    if args.tpu == "colab":
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        ) if params["use_tpu"] else None
    else:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            args.tpu) if params["use_tpu"] else None

    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=params["model_path"],
        save_checkpoints_steps=None,  # Disable the default saver
        save_checkpoints_secs=None,  # Disable the default saver
        log_step_count_steps=params["iterations"],
        save_summary_steps=params["iterations"],
        tpu_config=tpu_config.TPUConfig(
            num_shards=mesh_shape.size,
            iterations_per_loop=params["iterations"],
            num_cores_per_replica=1,
            experimental_host_call_every_n_steps=100,
            per_host_input_for_training=tpu_config.InputPipelineConfig.
            BROADCAST))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=params["use_tpu"],
        model_fn=dalle_model_fn,
        config=config,
        train_batch_size=params["train_batch_size"],
        eval_batch_size=params["eval_batch_size"],
        predict_batch_size=params["predict_batch_size"],
        params=params)

    has_predict_or_eval_steps = params["predict_steps"] > 0 or params[
        "eval_steps"] > 0
    if has_predict_or_eval_steps:
        # Eval and train - stop and predict and/or eval every checkpoint
        while current_step < params["train_steps"]:
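            # Train up to the next checkpoint boundary before checking the
            # predict/eval settings (both unimplemented in this example).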
            next_checkpoint = min(current_step + args.steps_per_checkpoint,
                                  params["train_steps"])
            estimator.train(input_fn=partial(dalle_input_fn, eval=False),
                            max_steps=next_checkpoint)
            current_step = next_checkpoint
            if params["predict_steps"] > 0:
                raise NotImplementedError
            if params["eval_steps"] > 0:
                raise NotImplementedError
        return
    else:
        # Else, just train
        while current_step < params["train_steps"]:
            # Train straight through to train_steps without stopping, then
            # re-read the global step so the loop can terminate.
            estimator.train(input_fn=partial(dalle_input_fn, eval=False),
                            max_steps=params["train_steps"])
            current_step = int(
                estimator_lib._load_global_step_from_checkpoint_dir(
                    params["model_path"]))
Example #7
def main(args):
    # Setup logging
    logger = setup_logging(args)

    # Read params of model
    params = fetch_model_params(args.model)

    # Fetch appropriate input functions
    input_fn = generic_text
    pred_input_fn = pred_input
    handle_pred_output_fn = handle_pred_output

    if params["mlm_training"]:
        mlm_sample_text_fn = partial(mlm_sample_text, params)
        input_fn = partial(generic_text, sample_text_fn=mlm_sample_text_fn)

    # Fetch encoder per params
    encoder = fetch_encoder(params)

    pred_input_fn = partial(pred_input_fn,
                            path_to_prompt=args.prompt,
                            logger=logger,
                            enc=encoder)

    # Sample from Dataset if check dataset flag is on
    if args.check_dataset:
        check_dataset(input_fn)

    # Confirm deletion of checkpoint files if --new flag is set
    if args.new:
        if yes_or_no(
                f"Are you sure you want to remove '{params['model_path']}' to start afresh?"
        ):
            remove_gs_or_filepath(params["model_path"])
        else:
            exit()

    # Save config to logdir for experiment management
    save_config(params, params["model_path"])

    # Add to params: auto_layout, auto_layout_and_mesh_shape, use_tpu, num_cores
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    params["num_cores"] = mesh_shape.size
    params["auto_layout"] = args.auto_layout
    params["auto_layout_and_mesh_shape"] = args.auto_layout_and_mesh_shape
    params["use_tpu"] = True if not args.tpu is None else False
    params["gpu_ids"] = args.gpu_ids
    params["steps_per_checkpoint"] = args.steps_per_checkpoint
    # Expand attention types param
    params["attention_types"] = expand_attention_types_params(
        params["attention_types"])
    assert len(params["attention_types"]) == params[
        "n_layer"]  # Assert that the length of expanded list = num layers
    params["predict_batch_size"] = params.get("predict_batch_size",
                                              1)  # Default to 1
    params["predict"] = args.predict
    params['model'] = params.get(
        "model", "GPT"
    )  # Default model selection to GPT since it's the only option for now

    # Sample quality of MoE models suffers when using the faster sampling method, so default to slow_sampling if
    # moe layers are present
    params["slow_sampling"] = params["moe_layers"] is not None

    logger.info(f"params = {params}")

    # Get eval tasks from params
    eval_tasks = params.get("eval_tasks", [])
    has_predict_or_eval_steps_or_eval_tasks = params[
        "predict_steps"] > 0 or params["eval_steps"] > 0 or len(eval_tasks) > 0

    for t in eval_tasks:
        assert t in task_descriptors, f"Eval task '{t}' is not known"
        task_descriptors[t]["init_fn"](params)

    # Set up TPUs and Estimator
    if args.tpu == "colab":
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        ) if params["use_tpu"] else None
    else:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            args.tpu) if params["use_tpu"] else None

    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=params["model_path"],
        save_checkpoints_steps=None,  # Disable the default saver
        save_checkpoints_secs=None,  # Disable the default saver
        log_step_count_steps=params["iterations"],
        save_summary_steps=params["iterations"],
        tpu_config=tpu_config.TPUConfig(
            num_shards=mesh_shape.size,
            iterations_per_loop=params["iterations"],
            num_cores_per_replica=1,
            per_host_input_for_training=tpu_config.InputPipelineConfig.
            BROADCAST))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=params["use_tpu"],
        model_fn=model_fn,
        config=config,
        train_batch_size=params["train_batch_size"],
        eval_batch_size=params["train_batch_size"],
        predict_batch_size=params["predict_batch_size"],
        params=params)

    def _make_task_estimator(task):
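        # Build a separate TPUEstimator per eval task, with "eval_task" set in
        # its own copy of params.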
        task_params = params.copy()
        task_params["eval_task"] = task
        return tpu_estimator.TPUEstimator(
            use_tpu=params["use_tpu"],
            model_fn=model_fn,
            config=config,
            train_batch_size=params["train_batch_size"],
            eval_batch_size=params["train_batch_size"],
            predict_batch_size=params["predict_batch_size"],
            params=task_params)

    eval_task_estimators = {
        task: _make_task_estimator(task)
        for task in eval_tasks
    }

    current_step = int(
        estimator_lib._load_global_step_from_checkpoint_dir(
            params["model_path"]))
    logger.info(f"Current step {current_step}")

    if args.predict:
        # Predict
        predictions = estimator.predict(input_fn=pred_input_fn)
        logger.info("Predictions generated")
        enc = fetch_encoder(params)
        handle_pred_output_fn(predictions,
                              logger,
                              enc,
                              params,
                              out_name=f"predictions_{current_step}")
        return

    elif has_predict_or_eval_steps_or_eval_tasks:
        # Eval and train - stop and predict and/or eval every checkpoint
        while current_step < params["train_steps"]:
            next_checkpoint = min(current_step + args.steps_per_checkpoint,
                                  params["train_steps"])

            estimator.train(input_fn=partial(input_fn, eval=False),
                            max_steps=next_checkpoint)
            current_step = next_checkpoint

            if params["predict_steps"] > 0:
                logger.info("Running prediction...")
                predictions = estimator.predict(input_fn=pred_input_fn)
                enc = fetch_encoder(params)
                handle_pred_output_fn(predictions,
                                      logger,
                                      enc,
                                      params,
                                      out_name=f"predictions_{current_step}")

            if params["eval_steps"] > 0:
                logger.info("Running evaluation...")
                eval_results = estimator.evaluate(input_fn=partial(input_fn,
                                                                   eval=True),
                                                  steps=params["eval_steps"])
                logger.info(f"Eval results: {eval_results}")

            for task in eval_tasks:
                logger.info(f"Starting evaluation task '{task}'")
                task_info = task_descriptors[task]["get_task_info_fn"](params)
                task_estimator = eval_task_estimators[task]
                task_input_fn = task_descriptors[task]["input_fn"]
                eval_results = task_estimator.evaluate(
                    input_fn=task_input_fn,
                    steps=task_info["n_steps"],
                    name=task)
                logger.info(f"Eval task '{task}' results: {eval_results}")
        return
    else:
        # Else, just train
        while current_step < params["train_steps"]:
            # Train straight through to train_steps without stopping, then
            # re-read the global step so the loop can terminate.
            estimator.train(input_fn=partial(input_fn, eval=False),
                            max_steps=params["train_steps"])
            current_step = int(
                estimator_lib._load_global_step_from_checkpoint_dir(
                    params["model_path"]))
Example #8
File: tpu.py | Project: shawwn/lm
    def execute(self, job: TPUJobSpec):
        "execut the give job spec"
        cluster = self.resolve()

        run_config = tpu_config.RunConfig(
            cluster=cluster,
            model_dir=job.model_path,
            save_checkpoints_steps=None,  # Disable the default saver
            save_checkpoints_secs=None,  # Disable the default saver
            log_step_count_steps=job.params["steps_per_iteration"],
            save_summary_steps=job.params["steps_per_checkpoint"],
            tpu_config=tpu_config.TPUConfig(
                num_shards=job.function.mesh_shape.size,
                iterations_per_loop=job.params["steps_per_iteration"],
                num_cores_per_replica=1,
                per_host_input_for_training=tpu_config.InputPipelineConfig.
                BROADCAST,
            ),
        )

        estimator = tpu_estimator.TPUEstimator(
            use_tpu=job.use_tpu,
            model_fn=job.function,
            config=run_config,
            train_batch_size=job.infeed.
            batch_size,  # these change with the configuration
            eval_batch_size=job.infeed.batch_size,
            predict_batch_size=job.infeed.batch_size,
            params=job.params,
        )

        assert job.train or job.eval

        if job.train:
            if tf.io.gfile.exists(job.model_path):
                logging.info("restoring checkpoint steps from %s",
                             job.model_path)
                current_step = int(
                    estimator_lib._load_global_step_from_checkpoint_dir(
                        job.model_path))
                logging.info("current step is now at %d", current_step)
            else:
                current_step = 0

            while current_step < job.max_steps:
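                # train() may stop before max_steps (e.g. if the session is
                # interrupted); re-read the checkpointed step and loop until
                # max_steps is actually reached.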
                estimator.train(input_fn=job.infeed.function,
                                max_steps=job.max_steps)
                current_step = int(
                    estimator_lib._load_global_step_from_checkpoint_dir(
                        job.model_path))
                logging.info("step %s", current_step)
            logging.info("completed device execution after %s steps",
                         current_step)

            if job.export:
                estimator.export_saved_model(job.export, job.signature)

            return {"current_step": current_step}

        if job.eval:
            # If eval is on - stop and eval every ckpt
            logging.info("starting to evaluate.")
            eval_results = estimator.evaluate(input_fn=job.infeed.function,
                                              steps=job.max_steps)
            logging.info("completed eval. results: %s", eval_results)
            return eval_results