コード例 #1
0
 def __init__(self, config: Config, dataset_label: str) -> None:
     """Load the input and label matrices of one dataset.

     The batch size is read from ``config["batch_size"]``. The two file
     names are resolved with ``config.filename`` from ``dataset_label``
     suffixed with ``"_inputs"`` and ``"_labels"``, and each file is read
     with ``self._load_data``.

     :param config: the settings, used for the batch size and file names
     :param dataset_label: the prefix of the dataset keys to load
     """
     self.batch_size = config["batch_size"]
     inputs_path, labels_path = [
         config.filename(dataset_label + suffix)
         for suffix in ("_inputs", "_labels")
     ]
     self.input_matrix = self._load_data(inputs_path)
     self.label_matrix = self._load_data(labels_path)
     # The second entry of the input matrix shape is kept as the input dimension
     self.input_dim = self.input_matrix.shape[1]
コード例 #2
0
def test_update_single_setting(default_config, write_yaml):
    """Check that one setting from a YAML file is used and another keeps its default."""
    config = Config(write_yaml({"fingerprint_len": 10}))

    untouched_key = "template_occurance"
    assert config["fingerprint_len"] == 10
    assert config[untouched_key] == default_config[untouched_key]
コード例 #3
0
def test_preprocess_rollout(write_yaml, shared_datadir, add_cli_arguments):
    """Run ``rollout_main`` on the ``dummy`` files and check what it writes.

    Verifies the number of lines in each produced CSV file and the content
    of the unique templates table.
    """
    settings = {
        "file_prefix": str(shared_datadir / "dummy"),
        "split_size": {"training": 0.6, "testing": 0.2, "validation": 0.2},
    }
    config_path = write_yaml(settings)
    add_cli_arguments(config_path)

    rollout_main()

    # Expected number of lines in each CSV file produced by the tool
    expected_line_counts = (
        ("dummy_template_library.csv", 10),
        ("dummy_training.csv", 6),
        ("dummy_testing.csv", 2),
        ("dummy_validation.csv", 2),
    )
    for csv_name, expected_count in expected_line_counts:
        with open(shared_datadir / csv_name, "r") as fileobj:
            lines = fileobj.read().splitlines()
        assert len(lines) == expected_count

    templates = pd.read_hdf(shared_datadir / "dummy_unique_templates.hdf5", "table")
    config = Config(config_path)
    assert len(templates) == 2
    expected_columns = ["retro_template", "library_occurence"]
    expected_columns.extend(config["metadata_headers"])
    for column in expected_columns:
        assert column in templates.columns
コード例 #4
0
def test_update_invalid_setting(default_config, write_yaml):
    """Check that a dictionary given for ``fingerprint_len`` leaves the default in place."""
    invalid_value = {"training": 0.8, "testing": 0.1, "validation": 0.1}
    config = Config(write_yaml({"fingerprint_len": invalid_value}))

    expected = default_config["fingerprint_len"]
    assert config["fingerprint_len"] == expected
コード例 #5
0
def _get_config() -> Config:
    """Parse the command line and create the configuration it points to.

    The command line takes a single positional argument, ``config``, which
    is the filename of a configuration file.

    :return: a ``Config`` created from the given filename
    """
    # The text describes the tool, so it is passed as ``description``.
    # Given positionally, argparse would treat it as the program name and
    # print the whole sentence in usage and error messages.
    parser = argparse.ArgumentParser(
        description="Tool to pre-process a template library to be used in training a expansion network policy"
    )
    parser.add_argument("config", help="the filename to a configuration file")
    args = parser.parse_args()

    return Config(args.config)
コード例 #6
0
def _get_config():
    """Parse the command line and create the configuration it points to.

    The command line takes a single positional argument, ``config``, which
    is the filename of a configuration file.

    :return: a ``Config`` created from the given filename
    """
    # The text describes the tool, so it is passed as ``description``.
    # Given positionally, argparse would treat it as the program name and
    # print the whole sentence in usage and error messages.
    parser = argparse.ArgumentParser(
        description="Tool to pre-process a template library to be used to train a recommender network"
    )
    parser.add_argument("config", help="the filename to a configuration file")
    args = parser.parse_args()

    return Config(args.config)
コード例 #7
0
def main() -> None:
    """Entry-point for the aizynth_training tool.

    Reads the filename of a configuration file from the command line,
    creates a ``Config`` from it and passes that to
    ``train_rollout_keras_model``.
    """
    # The text describes the tool, so it is passed as ``description``.
    # Given positionally, argparse would treat it as the program name.
    parser = argparse.ArgumentParser(description="Tool to train a network policy")
    parser.add_argument("config", help="the filename to a configuration file")
    args = parser.parse_args()

    config = Config(args.config)
    train_rollout_keras_model(config)
コード例 #8
0
def test_update_nested_setting(default_config, write_yaml):
    """Check that a nested ``split_size`` setting is used and another setting keeps its default."""
    config = Config(
        write_yaml({"split_size": {"training": 0.8, "testing": 0.1, "validation": 0.1}})
    )

    assert config["template_occurance"] == default_config["template_occurance"]
    expected_fractions = (("training", 0.8), ("testing", 0.1), ("validation", 0.1))
    for part, fraction in expected_fractions:
        assert config["split_size"][part] == fraction
コード例 #9
0
def _save_unique_templates(dataset: pd.DataFrame, config: Config) -> None:
    """Write one row per template code to the unique templates HDF5 file.

    The saved table is indexed and sorted by ``template_code``. It holds
    the ``retro_template`` column, the columns listed in
    ``config["metadata_headers"]`` and a ``library_occurence`` column with
    the number of rows of each ``template_hash`` group.

    :param dataset: the template library, including the ``template_hash``,
        ``template_code`` and ``retro_template`` columns
    :param config: the settings, used for the metadata columns and the
        output filename
    """
    # Group sizes in order of first appearance (sort=False).
    # NOTE(review): this lines up with the de-duplicated rows below only if
    # every template_code maps to exactly one template_hash - confirm.
    template_group = dataset.groupby("template_hash", sort=False).size()
    columns = ["retro_template", "template_code"] + config["metadata_headers"]
    dataset = dataset[columns]
    if "classification" in dataset.columns:
        # Replace the column instead of calling fillna(inplace=True) on a
        # selected column: the latter is a chained assignment that does
        # not update the frame under pandas Copy-on-Write.
        dataset = dataset.assign(
            classification=dataset["classification"].fillna("-")
        )
    dataset = dataset.drop_duplicates(subset="template_code", keep="first")
    dataset = dataset.assign(library_occurence=template_group.values)
    dataset = dataset.set_index("template_code").sort_index()
    # ``key`` is given by keyword; positional use is deprecated in pandas
    dataset.to_hdf(config.filename("unique_templates"), key="table")
コード例 #10
0
def _get_config():
    """Parse the command line into a configuration and a method name.

    The command line takes two positional arguments: ``config``, the
    filename of a configuration file, and ``method``, one of ``strict``,
    ``random`` or ``recommender``.

    :return: a ``Config`` created from the given filename, and the
        selected method
    """
    # The text describes the tool, so it is passed as ``description``.
    # Given positionally, argparse would treat it as the program name and
    # print the whole sentence in usage and error messages.
    parser = argparse.ArgumentParser(
        description="Tool to generate artifical negative reactions"
    )
    parser.add_argument("config", help="the filename to a configuration file")
    parser.add_argument(
        "method",
        choices=["strict", "random", "recommender"],
        help="the method to create random data",
    )
    args = parser.parse_args()

    return Config(args.config), args.method
コード例 #11
0
def _filter_dataset(config: Config) -> pd.DataFrame:
    """Read the raw template library, filter it and save the result.

    Steps, in order:

    1. read the ``raw_library`` file without a header row
    2. if ``config["remove_unsanitizable_products"]`` is set, keep only the
       rows selected by ``is_sanitizable`` applied to the products
    3. drop rows with a duplicated ``reaction_hash``
    4. keep rows whose ``template_hash`` occurs at least
       ``config["template_occurrence"]`` times
    5. add a ``template_code`` column with label-encoded template hashes
    6. write the result to the ``library`` file without header or index

    :param config: the settings
    :raises FileNotFoundError: if the raw library file does not exist
    :return: the filtered dataset
    """
    filename = config.filename("raw_library")
    if not os.path.exists(filename):
        raise FileNotFoundError(
            f"The file {filename} is missing - cannot proceed without the full template library."
        )

    # The last header is skipped as it is not available in the raw data
    raw_headers = config["library_headers"][:-1]
    library = pd.read_csv(filename, index_col=False, header=None, names=raw_headers)

    if config["remove_unsanitizable_products"]:
        products = library["products"].to_numpy()
        keep_mask = np.apply_along_axis(is_sanitizable, 0, [products])
        library = library[keep_mask]

    library = library.drop_duplicates(subset="reaction_hash")

    # Number of rows per template hash, most frequent first
    occurrences = library.groupby("template_hash").size()
    occurrences = occurrences.sort_values(ascending=False)
    frequent_hashes = occurrences[occurrences >= config["template_occurrence"]].index
    dataset = library[library["template_hash"].isin(frequent_hashes)]

    encoder = LabelEncoder()
    template_codes = encoder.fit_transform(dataset["template_hash"])
    dataset = dataset.assign(template_code=template_codes)

    dataset.to_csv(config.filename("library"), mode="w", header=False, index=False)
    return dataset
コード例 #12
0
ファイル: test_cli.py プロジェクト: wangxr0526/aizynthfinder
def test_preprocess_expansion_no_class(write_yaml, shared_datadir,
                                       add_cli_arguments):
    """Run ``expansion_main`` on the ``dummy_noclass`` files and check what it writes.

    The configuration lists library headers without a classification
    column. Verifies the number of lines in each produced CSV file and the
    content of the unique templates table.
    """
    library_headers = [
        "index",
        "ID",
        "reaction_hash",
        "reactants",
        "products",
        "retro_template",
        "template_hash",
    ]
    settings = {
        "library_headers": library_headers,
        "metadata_headers": ["template_hash"],
        "file_prefix": str(shared_datadir / "dummy_noclass"),
        "split_size": {"training": 0.6, "testing": 0.2, "validation": 0.2},
    }
    config_path = write_yaml(settings)
    add_cli_arguments(config_path)

    expansion_main()

    # Expected number of lines in each CSV file produced by the tool
    expected_line_counts = (
        ("dummy_noclass_template_library.csv", 10),
        ("dummy_noclass_training.csv", 6),
        ("dummy_noclass_testing.csv", 2),
        ("dummy_noclass_validation.csv", 2),
    )
    for csv_name, expected_count in expected_line_counts:
        with open(shared_datadir / csv_name, "r") as fileobj:
            lines = fileobj.read().splitlines()
        assert len(lines) == expected_count

    templates = pd.read_hdf(
        shared_datadir / "dummy_noclass_unique_templates.hdf5", "table")
    config = Config(config_path)
    assert len(templates) == 2
    expected_columns = ["retro_template", "library_occurrence"]
    expected_columns.extend(config["metadata_headers"])
    for column in expected_columns:
        assert column in templates.columns
コード例 #13
0
def main() -> None:
    """Entry-point for the aizynth_training tool.

    The command line takes two positional arguments: ``config``, the
    filename of a configuration file, and ``model``, one of ``expansion``,
    ``filter`` or ``recommender``. The training function that matches the
    selected model is called with the loaded configuration.
    """
    # The text describes the tool, so it is passed as ``description``.
    # Given positionally, argparse would treat it as the program name.
    parser = argparse.ArgumentParser(description="Tool to train a network policy")
    parser.add_argument("config", help="the filename to a configuration file")
    parser.add_argument(
        "model",
        choices=["expansion", "filter", "recommender"],
        help="the model to train",
    )
    args = parser.parse_args()

    config = Config(args.config)
    # argparse restricts ``model`` to the choices above, so the lookup
    # cannot fail
    trainers = {
        "expansion": train_expansion_keras_model,
        "filter": train_filter_keras_model,
        "recommender": train_recommender_keras_model,
    }
    trainers[args.model](config)
コード例 #14
0
def _setup_callbacks(config: Config) -> List[Any]:
    """Create the callbacks to pass to the model training.

    The callbacks are, in order: early stopping on ``val_loss``, a CSV
    logger appending to the ``_keras_training.log`` file, a checkpoint
    that saves only the best model according to ``loss``, and a
    learning-rate reduction on ``val_loss`` plateaus.

    The ``checkpoints`` directory under ``config["output_path"]`` is
    created if it does not exist.

    :param config: the settings, used for the log filename and output path
    :return: the list of callbacks
    """
    early_stopping = EarlyStopping(monitor="val_loss", patience=10)
    csv_logger = CSVLogger(config.filename("_keras_training.log"), append=True)

    checkpoint_path = os.path.join(config["output_path"], "checkpoints")
    # exist_ok avoids the race between checking for and creating the
    # directory; makedirs also creates a missing output directory
    os.makedirs(checkpoint_path, exist_ok=True)
    checkpoint = ModelCheckpoint(
        os.path.join(checkpoint_path, "keras_model.hdf5"),
        monitor="loss",
        save_best_only=True,
    )

    reduce_lr = ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=5,
        verbose=0,
        mode="auto",
        min_delta=0.000001,
        cooldown=0,
        min_lr=0,
    )
    return [early_stopping, csv_logger, checkpoint, reduce_lr]
コード例 #15
0
 def __init__(self, config: Config, dataset_label: str) -> None:
     """Load a second input matrix in addition to what the parent class sets up.

     The file name is resolved with ``config.filename`` from
     ``dataset_label`` suffixed with ``"_inputs2"`` and the file is read
     with ``self._load_data``.

     :param config: the settings, passed on to the parent initializer
     :param dataset_label: the prefix of the dataset keys to load
     """
     super().__init__(config, dataset_label)
     second_inputs_path = config.filename(dataset_label + "_inputs2")
     self.input_matrix2 = self._load_data(second_inputs_path)
コード例 #16
0
def default_config():
    """Return a ``Config`` constructed with no arguments."""
    config = Config()
    return config
コード例 #17
0
def _save_unique_templates(dataset: pd.DataFrame, config: Config) -> None:
    """Write one row per template code to the unique templates HDF5 file.

    The saved table holds the ``retro_template`` column, indexed and
    sorted by ``template_code``. For duplicated codes the first row is
    kept.

    :param dataset: the template library, including the ``retro_template``
        and ``template_code`` columns
    :param config: the settings, used for the output filename
    """
    templates = dataset[["retro_template", "template_code"]]
    templates = templates.drop_duplicates(subset="template_code", keep="first")
    templates = templates.set_index("template_code").sort_index()
    # ``key`` is given by keyword; positional use is deprecated in pandas
    templates.to_hdf(config.filename("unique_templates"), key="table")
コード例 #18
0
def test_empty_config(default_config, write_yaml):
    """Check that an empty YAML file gives the same settings as the default configuration."""
    config = Config(write_yaml({}))

    expected_settings = default_config._config
    assert config._config == expected_settings