def create_serialization_dir(params: Params, serialization_dir: str, recover: bool) -> None:
    """
    This function creates the serialization directory if it doesn't exist.  If it already exists,
    then it verifies that we're recovering from a training with an identical configuration.

    Parameters
    ----------
    params: ``Params``
        A parameter object specifying an AllenNLP Experiment.
    serialization_dir: ``str``
        The directory in which to save results and logs.
    recover: ``bool``
        If ``True``, we will try to recover from an existing serialization directory, and crash if
        the directory doesn't exist, or doesn't match the configuration we're given.

    Raises
    ------
    ConfigurationError
        If the directory exists and ``recover`` is ``False``; if ``recover`` is ``True`` but the
        directory does not exist; if the existing directory has no stored config file; or if the
        stored configuration differs from ``params``.
    """
    if not os.path.exists(serialization_dir):
        if recover:
            raise ConfigurationError(f"--recover specified but serialization_dir ({serialization_dir}) "
                                     "does not exist. There is nothing to recover from.")
        os.makedirs(serialization_dir)
        return

    if serialization_dir == '/output':
        # Special-casing the beaker output directory, which will already exist when training
        # starts. No configuration check is performed for it.
        return

    if not recover:
        raise ConfigurationError(f"Serialization directory ({serialization_dir}) already exists. "
                                 f"Specify --recover to recover training from existing output.")

    logger.info(f"Recovering from prior training at {serialization_dir}.")

    recovered_config_file = os.path.join(serialization_dir, CONFIG_NAME)
    if not os.path.exists(recovered_config_file):
        raise ConfigurationError("The serialization directory already exists but doesn't "
                                 "contain a config.json. You probably gave the wrong directory.")

    loaded_params = Params.from_file(recovered_config_file)

    # Check whether any of the training configuration differs from the configuration we are
    # resuming. Every difference is logged first, and then a single ConfigurationError is raised,
    # so the user sees all mismatches at once.
    fail = False
    flat_params = params.as_flat_dict()
    flat_loaded = loaded_params.as_flat_dict()

    for key in flat_params.keys() - flat_loaded.keys():
        logger.error(f"Key '{key}' found in training configuration but not in the serialization "
                     f"directory we're recovering from.")
        fail = True

    for key in flat_loaded.keys() - flat_params.keys():
        logger.error(f"Key '{key}' found in the serialization directory we're recovering from "
                     f"but not in the training config.")
        fail = True

    for key in flat_params.keys():
        # Only compare values for keys present on both sides. Keys missing from the stored
        # configuration were already reported above; indexing ``flat_loaded`` with them would
        # raise a KeyError and hide the intended ConfigurationError.
        if key in flat_loaded and flat_params[key] != flat_loaded[key]:
            logger.error(f"Value for '{key}' in training configuration does not match that the value in "
                         f"the serialization directory we're recovering from: "
                         f"{flat_params[key]} != {flat_loaded[key]}")
            fail = True

    if fail:
        raise ConfigurationError("Training configuration does not match the configuration we're "
                                 "recovering from.")
def create_serialization_dir(params: Params, serialization_dir: str) -> None:
    """
    This function creates the serialization directory if it doesn't exist.  If it already exists,
    then it verifies that we're recovering from a training with an identical configuration.

    Parameters
    ----------
    params: Params, required.
        A parameter object specifying an AllenNLP Experiment.
    serialization_dir: str, required
        The directory in which to save results and logs.

    Raises
    ------
    ConfigurationError
        If the directory exists but has no stored config file, or if the stored configuration
        differs from ``params``.
    """
    if not os.path.exists(serialization_dir):
        os.makedirs(serialization_dir)
        return

    logger.info(f"Recovering from prior training at {serialization_dir}.")

    recovered_config_file = os.path.join(serialization_dir, CONFIG_NAME)
    if not os.path.exists(recovered_config_file):
        raise ConfigurationError(
            "The serialization directory already exists but doesn't "
            "contain a config.json. You probably gave the wrong directory."
        )

    loaded_params = Params.from_file(recovered_config_file)

    # Check whether any of the training configuration differs from the configuration we are
    # resuming. Every difference is logged first, and then a single ConfigurationError is raised,
    # so the user sees all mismatches at once.
    fail = False
    flat_params = params.as_flat_dict()
    flat_loaded = loaded_params.as_flat_dict()

    for key in flat_params.keys() - flat_loaded.keys():
        logger.error(
            f"Key '{key}' found in training configuration but not in the serialization "
            f"directory we're recovering from.")
        fail = True

    for key in flat_loaded.keys() - flat_params.keys():
        logger.error(
            f"Key '{key}' found in the serialization directory we're recovering from "
            f"but not in the training config.")
        fail = True

    for key in flat_params.keys():
        # Only compare values for keys present on both sides. Keys missing from the stored
        # configuration were already reported above; indexing ``flat_loaded`` with them would
        # raise a KeyError and hide the intended ConfigurationError.
        if key in flat_loaded and flat_params[key] != flat_loaded[key]:
            logger.error(
                f"Value for '{key}' in training configuration does not match that the value in "
                f"the serialization directory we're recovering from: "
                f"{flat_params[key]} != {flat_loaded[key]}")
            fail = True

    if fail:
        raise ConfigurationError(
            "Training configuration does not match the configuration we're "
            "recovering from.")