from typing import Optional, Dict

import tensorflow as tf

# ModelDir, PreprocessedData, TrainConfig, prepare_data, _train and
# _train_async are project-internal helpers assumed to be imported from
# elsewhere in this codebase.


def resume_training(out: ModelDir, notes: Optional[str] = None, dry_run=False, start_eval=False):
    """Resume training an existing model"""
    train_params = out.get_last_train_params()
    model = out.get_model()

    train_data = train_params["data"]
    evaluators = train_params["evaluators"]
    params = train_params["train_params"]
    params.num_epochs = 24 * 3  # hard-coded epoch budget for the resumed run

    if isinstance(train_data, PreprocessedData):
        # TODO: don't hard-code the number of preprocessing processes
        train_data.preprocess(6, 1000)

    latest = tf.train.latest_checkpoint(out.save_dir)
    if latest is None:
        raise ValueError("No checkpoint to resume from found in " + out.save_dir)

    _train(model, train_data, latest, None, False, params, evaluators, out,
           notes, dry_run, start_eval)
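# A minimal usage sketch for the signature above ("models/squad-baseline" is a
# hypothetical model directory; it must already contain saved train params and
# at least one checkpoint for tf.train.latest_checkpoint to find):
#
#     resume_training(ModelDir("models/squad-baseline"), notes="resumed run")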
def resume_training(model_to_resume: str,
                    dataset_oversampling: Dict[str, int],
                    checkpoint: Optional[str] = None,
                    epochs: Optional[int] = None):
    """Resume training a partially trained model (or finetune an existing model).

    :param model_to_resume: path to the model directory of the model to resume training
    :param dataset_oversampling: dictionary mapping dataset names to integer counts of
        how many times to oversample them
    :param checkpoint: optional checkpoint to resume from; uses the latest if not specified
    :param epochs: optional number of epochs to train for; defaults to 24 if not specified
    """
    out = ModelDir(model_to_resume)
    train_params = out.get_last_train_params()
    evaluators = train_params["evaluators"]
    params = train_params["train_params"]
    params.num_epochs = epochs if epochs is not None else 24

    model = out.get_model()
    data = prepare_data(model, TrainConfig(), dataset_oversampling)

    if checkpoint is None:
        checkpoint = tf.train.latest_checkpoint(out.save_dir)

    _train_async(model=model,
                 data=data,
                 checkpoint=checkpoint,
                 parameter_checkpoint=None,
                 save_start=False,
                 train_params=params,
                 evaluators=evaluators,
                 out=out,
                 notes=None,
                 dry_run=False,
                 start_eval=False)
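# A minimal usage sketch for the refactored signature (the model path and
# dataset names are hypothetical; the integer counts say how many times each
# named dataset is oversampled when the training data is rebuilt):
#
#     resume_training("models/triviaqa-web",
#                     dataset_oversampling={"squad": 2, "triviaqa": 1},
#                     epochs=10)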