Code Example #1
File: trainer.py  Project: artiom-zayats/docqa_squad
# Assumed imports: tensorflow is needed for the checkpoint lookup below;
# ModelDir, PreprocessedData and _train are defined elsewhere in the
# surrounding docqa codebase
import tensorflow as tf


def resume_training(out: ModelDir,
                    notes: str = None,
                    dry_run=False,
                    start_eval=False):
    """ Resume training an existing model """

    # Reload the dataset, evaluators and training hyper-parameters that were
    # saved alongside the model on the previous run
    train_params = out.get_last_train_params()
    model = out.get_model()

    train_data = train_params["data"]

    evaluators = train_params["evaluators"]
    params = train_params["train_params"]
    # Extend the epoch budget so training can continue past the original schedule
    params.num_epochs = 24 * 3

    if isinstance(train_data, PreprocessedData):
        # TODO don't hard code # of processes
        train_data.preprocess(6, 1000)

    # Find the most recent checkpoint in the model directory; fail early if
    # there is nothing to resume from
    latest = tf.train.latest_checkpoint(out.save_dir)
    if latest is None:
        raise ValueError("No checkpoint to resume from found in " +
                         out.save_dir)

    _train(model, train_data, latest, None, False, params, evaluators, out,
           notes, dry_run, start_eval)
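
A minimal usage sketch for this variant, assuming the project's ModelDir class is importable; the directory path and flag values below are illustrative assumptions, not taken from the artiom-zayats/docqa_squad project.

# Hypothetical call: resume the run stored under "models/squad-run1" and start
# evaluation once training restarts (path and flags are made up for illustration)
resume_training(ModelDir("models/squad-run1"),
                notes="resumed after interruption",
                dry_run=False,
                start_eval=True)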
Code Example #2
# Assumed imports: typing and tensorflow come from outside the project; the
# remaining names (ModelDir, prepare_data, TrainConfig, _train_async) are
# provided by the surrounding docqa-based codebase
from typing import Dict, Optional

import tensorflow as tf


def resume_training(model_to_resume: str,
                    dataset_oversampling: Dict[str, int],
                    checkpoint: Optional[str] = None,
                    epochs: Optional[int] = None):
    """Resume training on a partially trained model (or finetune an existing model)


    :param model_to_resume: path to the model directory of the model to resume training
    :param dataset_oversampling: dictionary mapping dataset names to integer counts of how much
       to oversample them
    :param checkpoint: optional string to specify which checkpoint to resume from. Uses the latest
         if not specified
    :param epochs: Optional int specifying how many epochs to train for. If not detailed, runs for 24
    """
    out = ModelDir(model_to_resume)
    # Reload the evaluators and hyper-parameters saved with the model, then
    # override the epoch budget
    train_params = out.get_last_train_params()
    evaluators = train_params["evaluators"]
    params = train_params["train_params"]
    params.num_epochs = epochs if epochs is not None else 24
    model = out.get_model()

    notes = None
    dry_run = False
    # Rebuild the training data, applying the requested per-dataset oversampling
    data = prepare_data(model, TrainConfig(), dataset_oversampling)
    if checkpoint is None:
        # Fall back to the most recent checkpoint in the model directory
        checkpoint = tf.train.latest_checkpoint(out.save_dir)

    _train_async(model=model,
                 data=data,
                 checkpoint=checkpoint,
                 parameter_checkpoint=None,
                 save_start=False,
                 train_params=params,
                 evaluators=evaluators,
                 out=out,
                 notes=notes,
                 dry_run=dry_run,
                 start_eval=False)
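
A matching usage sketch for this variant; the directory name, dataset keys, and epoch count below are hypothetical and only show the shape of the arguments.

# Hypothetical call: resume from the latest checkpoint, oversampling one dataset
# twice as heavily as the other, and train for 10 more epochs
resume_training(model_to_resume="models/cape-run1",
                dataset_oversampling={"squad": 1, "wiki": 2},
                checkpoint=None,   # None -> the latest checkpoint is used
                epochs=10)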