コード例 #1
0
ファイル: run_experiment.py プロジェクト: zxhjiutian/gpt-neo
# --- Remaining CLI options. `parser` itself (and options such as --model and
# --experiment_name, which are read below but not registered in these lines)
# are presumably defined earlier in the file.
parser.add_argument('--steps_per_checkpoint', type=int, default=5000)
# NOTE(review): with action="store_false", args.autostack defaults to True and
# passing --autostack turns it OFF; the flag name reads the opposite way —
# confirm this is intended before relying on it.
parser.add_argument('--autostack', action="store_false")
parser.add_argument('--auto_layout', action="store_true")
parser.add_argument('--auto_layout_and_mesh_shape', action="store_true")
parser.add_argument('--new', action='store_true')
parser.add_argument('--test', action='store_true')
parser.add_argument('--eval', action='store_true')
parser.add_argument('--predict', action='store_true')
parser.add_argument('--no_delete_tpu', action='store_true')
# Seconds. Presumably the grace period before the first tensorboard log is
# expected (larger than --heartbeat_timeout to cover startup) — TODO confirm
# at the use site.
parser.add_argument('--initial_heartbeat_timeout', type=int, default=7200)
parser.add_argument(
    '--heartbeat_timeout', type=int, default=1800
)  # kill and restart if nothing logged to tensorboard in this many seconds
args = parser.parse_args()

# Load the parameter set for the model selected with --model.
params = fetch_model_params(args.model)

# Sacred experiment named after --experiment_name; runs are reported to a
# MongoDB instance at 127.0.0.1:27017, database 'db'.
# NOTE(review): the credentials are hard-coded (redacted as '******' in this
# copy); consider reading them from the environment or a config file.
ex = sacred.Experiment(args.experiment_name)
ex.observers.append(
    sacred.observers.QueuedMongoObserver(url='127.0.0.1:27017',
                                         db_name='db',
                                         username='******',
                                         password='******'))


def get_open_port(lo=8000, hi=8100):
    """Return the first port in ``[lo, hi)`` that nothing on localhost is listening on.

    A port counts as free when a TCP connect to ``localhost`` on that port fails
    (``connect_ex`` returns a non-zero error code). This is only a probe: another
    process may still claim the port between this check and the caller using it.

    Args:
        lo: first port to try (inclusive).
        hi: end of the port range (exclusive).

    Returns:
        The first port number in the range with no listener.

    Raises:
        RuntimeError: if the range is empty or every port in it accepted a
            connection, instead of silently returning ``None``.
    """
    for port in range(lo, hi):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            # connect_ex returns 0 when something accepted the connection (port
            # in use); any error code means nobody is listening there.
            if probe.connect_ex(('localhost', port)) != 0:
                return port
    raise RuntimeError(
        f"No open port found on localhost in range [{lo}, {hi})")
コード例 #2
0
def main(args):
    """Build a TPUEstimator from the model config and predict, train+eval, or train.

    In order: set up logging, load params for ``args.model``, pick the input
    functions, optionally sample the dataset (``--check_dataset``) and wipe
    ``params["model_path"]`` after confirmation (``--new``), save the config,
    fold command-line settings into ``params``, build the estimator (plus one per
    entry of ``params["eval_tasks"]``), then take exactly one of three paths:

    * ``args.predict``: run one prediction pass and return.
    * ``predict_steps > 0``, ``eval_steps > 0`` or any eval task configured:
      train in chunks of ``args.steps_per_checkpoint`` steps, running
      prediction and/or evaluation after each chunk, up to ``train_steps``.
    * otherwise: train straight to ``params["train_steps"]``.

    Args:
        args: parsed command-line options. Attributes read here: ``model``,
            ``prompt``, ``check_dataset``, ``new``, ``auto_layout``,
            ``auto_layout_and_mesh_shape``, ``tpu``, ``gpu_ids``,
            ``steps_per_checkpoint`` and ``predict`` (plus whatever
            ``setup_logging`` reads).

    Note:
        ``params`` is indexed with ``[]`` for ``mlm_training``, ``model_path``,
        ``mesh_shape``, ``attention_types``, ``n_layer``, ``moe_layers``,
        ``predict_steps``, ``eval_steps``, ``iterations``, ``train_batch_size``
        and ``train_steps``; presumably a dict, so a missing key raises
        ``KeyError`` when that line is reached.
    """
    # Setup logging
    logger = setup_logging(args)

    # Read params of model
    params = fetch_model_params(args.model)

    # Fetch appropriate input functions (input_fn may be swapped for MLM below)
    input_fn = generic_text
    pred_input_fn = pred_input
    handle_pred_output_fn = handle_pred_output

    # Masked-LM training: wrap generic_text with the MLM sampling function.
    if params["mlm_training"]:
        mlm_sample_text_fn = partial(mlm_sample_text, params)
        input_fn = partial(generic_text, sample_text_fn=mlm_sample_text_fn)

    # Fetch encoder per params
    encoder = fetch_encoder(params)

    # Pre-bind the prompt path, logger and encoder for the prediction input fn.
    pred_input_fn = partial(pred_input_fn,
                            path_to_prompt=args.prompt,
                            logger=logger,
                            enc=encoder)

    # Sample from Dataset if check dataset flag is on
    if args.check_dataset:
        check_dataset(input_fn)

    # Confirm deletion of checkpoint files if --new flag is set
    if args.new:
        if yes_or_no(
                f"Are you sure you want to remove '{params['model_path']}' to start afresh?"
        ):
            remove_gs_or_filepath(params["model_path"])
        else:
            # User declined: stop without touching model_path.
            # NOTE(review): `exit` is the interactive helper installed by the
            # `site` module; `sys.exit()` is the safer choice in scripts.
            exit()

    # Save config to logdir for experiment management.
    # NOTE(review): this runs before the CLI-derived keys below are added to
    # params, so (presumably) they are not part of the saved config — confirm.
    save_config(params, params["model_path"])

    # Add to params: auto_layout, auto_layout_and_mesh_shape, use_tpu, num_cores
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    params["num_cores"] = mesh_shape.size
    params["auto_layout"] = args.auto_layout
    params["auto_layout_and_mesh_shape"] = args.auto_layout_and_mesh_shape
    # use_tpu is True whenever --tpu was given at all (including "colab").
    params["use_tpu"] = True if not args.tpu is None else False
    params["gpu_ids"] = args.gpu_ids
    params["steps_per_checkpoint"] = args.steps_per_checkpoint
    # Expand attention types param
    params["attention_types"] = expand_attention_types_params(
        params["attention_types"])
    assert len(params["attention_types"]) == params[
        "n_layer"]  # Assert that the length of expanded list = num layers
    params["predict_batch_size"] = params.get("predict_batch_size",
                                              1)  # Default to 1
    params["predict"] = args.predict
    params['model'] = params.get(
        "model", "GPT"
    )  # Default model selection to GPT since it's the only option for now

    # Sample quality of MoE models suffers when using the faster sampling method, so default to slow_sampling if
    # moe layers are present
    params[
        "slow_sampling"] = True if params["moe_layers"] is not None else False

    logger.info(f"params = {params}")

    # Get eval tasks from params (none by default). Any periodic prediction,
    # evaluation or eval task switches training to the interleaved loop below.
    eval_tasks = params.get("eval_tasks", [])
    has_predict_or_eval_steps_or_eval_tasks = params[
        "predict_steps"] > 0 or params["eval_steps"] > 0 or len(eval_tasks) > 0

    # Fail fast on unknown task names, then run each task's init_fn on params.
    for t in eval_tasks:
        assert t in task_descriptors, f"Eval task '{t}' is not known"
        task_descriptors[t]["init_fn"](params)

    # Set up TPUs and Estimator.
    # "colab": resolver constructed without a TPU name (presumably auto-detected
    # in Colab); otherwise the --tpu value is passed. No resolver without a TPU.
    if args.tpu == "colab":
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        ) if params["use_tpu"] else None
    else:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            args.tpu) if params["use_tpu"] else None

    # Step-rate logging, summaries and the TPU loop size all use
    # params["iterations"]; one shard per mesh core.
    # NOTE(review): with the default saver disabled, checkpoints are presumably
    # written by model_fn using params["steps_per_checkpoint"] — confirm.
    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=params["model_path"],
        save_checkpoints_steps=None,  # Disable the default saver
        save_checkpoints_secs=None,  # Disable the default saver
        log_step_count_steps=params["iterations"],
        save_summary_steps=params["iterations"],
        tpu_config=tpu_config.TPUConfig(
            num_shards=mesh_shape.size,
            iterations_per_loop=params["iterations"],
            num_cores_per_replica=1,
            per_host_input_for_training=tpu_config.InputPipelineConfig.
            BROADCAST))

    # Evaluation reuses the training batch size.
    estimator = tpu_estimator.TPUEstimator(
        use_tpu=params["use_tpu"],
        model_fn=model_fn,
        config=config,
        train_batch_size=params["train_batch_size"],
        eval_batch_size=params["train_batch_size"],
        predict_batch_size=params["predict_batch_size"],
        params=params)

    # One estimator per eval task: same config and model_fn, with a (shallow)
    # copy of params carrying "eval_task" — presumably so model_fn can tell
    # which task it is serving.
    def _make_task_estimator(task):
        task_params = params.copy()
        task_params["eval_task"] = task
        return tpu_estimator.TPUEstimator(
            use_tpu=params["use_tpu"],
            model_fn=model_fn,
            config=config,
            train_batch_size=params["train_batch_size"],
            eval_batch_size=params["train_batch_size"],
            predict_batch_size=params["predict_batch_size"],
            params=task_params)

    eval_task_estimators = {
        task: _make_task_estimator(task)
        for task in eval_tasks
    }

    # Resume point: global step of the latest checkpoint under model_path
    # (presumably 0 when none exists).
    # NOTE(review): `_load_global_step_from_checkpoint_dir` is a private TF API.
    current_step = int(
        estimator_lib._load_global_step_from_checkpoint_dir(
            params["model_path"]))
    logger.info(f"Current step {current_step}")

    if args.predict:
        # Predict once with the current checkpoint and stop.
        predictions = estimator.predict(input_fn=pred_input_fn)
        logger.info("Predictions generated")
        # NOTE(review): re-fetches the encoder; `encoder` from above could
        # likely be reused — confirm fetch_encoder has no side effects.
        enc = fetch_encoder(params)
        handle_pred_output_fn(predictions,
                              logger,
                              enc,
                              params,
                              out_name=f"predictions_{current_step}")
        return

    elif has_predict_or_eval_steps_or_eval_tasks:
        # Eval and train - stop and predict and/or eval every checkpoint
        while current_step < params["train_steps"]:
            # Advance one checkpoint interval, capped at train_steps.
            next_checkpoint = min(current_step + args.steps_per_checkpoint,
                                  params["train_steps"])

            # max_steps is an absolute global step, not a step count.
            estimator.train(input_fn=partial(input_fn, eval=False),
                            max_steps=next_checkpoint)
            current_step = next_checkpoint

            if params["predict_steps"] > 0:
                logger.info("Running prediction...")
                predictions = estimator.predict(input_fn=pred_input_fn)
                enc = fetch_encoder(params)
                handle_pred_output_fn(predictions,
                                      logger,
                                      enc,
                                      params,
                                      out_name=f"predictions_{current_step}")

            if params["eval_steps"] > 0:
                logger.info("Running evaluation...")
                eval_results = estimator.evaluate(input_fn=partial(input_fn,
                                                                   eval=True),
                                                  steps=params["eval_steps"])
                logger.info(f"Eval results: {eval_results}")

            # Each task uses its own estimator, input_fn and step count.
            for task in eval_tasks:
                logger.info(f"Starting evaluation task '{task}'")
                task_info = task_descriptors[task]["get_task_info_fn"](params)
                task_estimator = eval_task_estimators[task]
                task_input_fn = task_descriptors[task]["input_fn"]
                eval_results = task_estimator.evaluate(
                    input_fn=task_input_fn,
                    steps=task_info["n_steps"],
                    name=task)
                logger.info(f"Eval task '{task}' results: {eval_results}")
        return
    else:
        # Else, just train straight to train_steps.
        # NOTE(review): in the lines visible here, current_step is never updated
        # inside this loop. If nothing after estimator.train() updates it, the
        # loop never ends: once the global step reaches train_steps, train()
        # with the same max_steps does no further work and returns. Confirm.
        while current_step < params["train_steps"]:
            # Else, don't stop and restart
            estimator.train(input_fn=partial(input_fn, eval=False),
                            max_steps=params["train_steps"])