예제 #1
0
def _create_batch_from_dir(batch_dir: str):
    path = Path(batch_dir)
    if not path.is_dir():  # pragma: no cover
        raise ValueError("%s is not a directory" % batch_dir)

    if not (path / HEADER_FILE).is_file():  # pragma: no cover
        raise ValueError("missing headers")
    headers = json.loads(open(path / HEADER_FILE).read())

    if not (path / CONFIG_FILE).is_file():  # pragma: no cover
        raise ValueError("missing model param file")

    config = config_from_model_dir(batch_dir)

    # training path can be empty, since we will not need access
    # to training data simply for read-only data generation
    train_path = ""

    batch = Batch(
        checkpoint_dir=batch_dir,
        input_data_path=train_path,
        headers=headers,
        config=config,
    )

    batch.load_validator_from_file()

    return batch
예제 #2
0
def _create_batch_from_dir(batch_dir: str):
    path = Path(batch_dir)
    if not path.is_dir():  # pragma: no cover
        raise ValueError("%s is not a directory" % batch_dir)

    if not (path / HEADER_FILE).is_file():  # pragma: no cover
        raise ValueError("missing headers")
    headers = json.loads(open(path / HEADER_FILE).read())

    if not (path / CONFIG_FILE).is_file():  # pragma: no cover
        raise ValueError("missing model param file")

    config = config_from_model_dir(batch_dir)

    # training path can be empty, since we will not need access
    # to training data simply for read-only data generation
    train_path = ""

    # Wrap the user supplied callback with a _BatchEpochCallback so we have the batch number too.
    if config.epoch_callback is not None:
        batch_count = int(Path(batch_dir).name.split("_")[-1])
        config.epoch_callback = _BatchEpochCallback(config.epoch_callback,
                                                    batch_count).callback

    batch = Batch(
        checkpoint_dir=batch_dir,
        input_data_path=train_path,
        headers=headers,
        config=config,
    )

    batch.load_validator_from_file()

    return batch
예제 #3
0
def test_load_legacy_config(model_name, dp, expected_learning_rate):
    """Check that a config stored under the ``0.14.x`` test data still loads.

    The loaded object must be a ``TensorFlowConfig`` with no
    ``dp_learning_rate`` entry in its instance dict, and its
    ``learning_rate`` and ``dp`` must match the expected values passed in.
    """
    loaded = config_from_model_dir(test_data_dir / '0.14.x' / model_name)
    instance_attrs = vars(loaded)

    assert isinstance(loaded, TensorFlowConfig)
    assert 'dp_learning_rate' not in instance_attrs
    assert loaded.learning_rate == expected_learning_rate
    assert loaded.dp == dp
예제 #4
0
 def _generate(
     self, model_dir: Path, count: int, file_name: str, seed, validator
 ) -> str:
     """Generate ``count`` lines from the model in ``model_dir`` into a file.

     When ``is_model_dir_batch_mode(model_dir)`` is true, a ``DataFrameBatch``
     in read mode generates all batch lines and the resulting DataFrame is
     written to ``<file_name>.csv``. Otherwise the config is loaded from the
     first entry found inside ``model_dir``, text is produced with
     ``generate_text`` and each kept line is written to ``<file_name>.txt``.

     Args:
         model_dir: Directory holding the extracted model data.
         count: Number of lines to request from the generator.
         file_name: Output file name without extension.
         seed: Optional seed; a dict in batch mode (passed as
             ``seed_fields``), a string otherwise (passed as
             ``start_string``).
         validator: Line validator passed to ``generate_text``; it is not
             used in batch mode.

     Returns:
         The name of the file that was written.

     Raises:
         TypeError: If ``seed`` has the wrong type for the detected mode.
         FileNotFoundError: If ``model_dir`` is empty in non-batch mode.
     """
     invalid_budget = max(count, MAX_INVALID)

     if is_model_dir_batch_mode(model_dir):
         if seed is not None and not isinstance(seed, dict):
             raise TypeError("Seed must be a dict in batch mode")
         out_fname = f"{file_name}.csv"
         batcher = DataFrameBatch(mode="read", checkpoint_dir=str(model_dir))
         batcher.generate_all_batch_lines(
             num_lines=count,
             max_invalid=invalid_budget,
             parallelism=1,
             seed_fields=seed,
         )
         out_df = batcher.batches_to_df()
         out_df.to_csv(out_fname, index=False)
         return out_fname

     # Validate the seed before touching the model so bad input fails fast.
     if seed is not None and not isinstance(seed, str):
         raise TypeError("seed must be a string")

     # The model data will be inside of a single directory when a simple model is used. If it
     # was archived correctly, there should only be a single directory inside the archive
     actual_dir = next(model_dir.glob("*"), None)
     if actual_dir is None:
         # A bare next() would leak an opaque StopIteration to the caller.
         raise FileNotFoundError(f"no model data found inside {model_dir}")
     config = config_from_model_dir(actual_dir)

     # Keep lines that validated, plus lines whose validity is None.
     kept_lines = [
         data.text
         for data in generate_text(
             config,
             num_lines=count,
             line_validator=validator,
             max_invalid=invalid_budget,
             parallelism=1,
             start_string=seed,
         )
         if data.valid or data.valid is None
     ]

     out_fname = file_name + ".txt"
     with open(out_fname, "w") as fout:
         fout.writelines(line + "\n" for line in kept_lines)
     return out_fname