def auto_neb(m1,
             m2,
             graph: MultiGraph,
             model: models.ModelWrapper,
             config: config.AutoNEBConfig,
             callback: callable = None):
    """Run the outstanding AutoNEB cycles for the connection ``m1`` -- ``m2``.

    Every finished cycle is stored in ``graph`` as an edge between ``m1`` and
    ``m2`` whose key is the 1-based cycle index. If such edges already exist,
    work resumes after the highest stored key, using that edge's data as the
    starting point. Otherwise the initial path consists only of the
    flattened ``"coords"`` of the two nodes.

    Args:
        m1: Key of the first node in ``graph``.
        m2: Key of the second node in ``graph``.
        graph: Graph holding the nodes and receiving one edge per cycle.
        model: Passed through unchanged to ``neb``.
        config: Supplies ``cycle_count`` and the per-cycle ``neb_configs``.
        callback: Optional zero-argument callable invoked after each cycle
            has been added to the graph.

    Raises:
        AssertionError: If the first cycle to run lies beyond
            ``config.cycle_count`` (i.e. all cycles are already stored).
    """
    if m2 not in graph[m1]:
        # No edge yet: start from a two-point path made of both end points.
        end_points = [graph.nodes[m]["coords"].view(1, -1) for m in (m1, m2)]
        connection_data = {
            "path_coords": torch.cat(end_points),
            "target_distances": torch.ones(1),
        }
        start_cycle_idx = 1
    else:
        # Resume: edge keys are cycle indices, so the largest is the latest.
        stored_cycles = graph[m1][m2]
        latest_idx = max(stored_cycles)
        connection_data = stored_cycles[latest_idx]
        start_cycle_idx = latest_idx + 1
    assert start_cycle_idx <= config.cycle_count

    # Each cycle's result is the input of the next one and is stored on CPU.
    pending_cycles = range(start_cycle_idx, config.cycle_count + 1)
    for cycle_idx in helper.pbar(pending_cycles, "AutoNEB"):
        # neb_configs is 0-based while cycle indices start at 1.
        connection_data = neb(connection_data, model,
                              config.neb_configs[cycle_idx - 1])
        cpu_data = helper.move_to(connection_data, "cpu")
        graph.add_edge(m1, m2, key=cycle_idx, **cpu_data)
        if callback is None:
            continue
        callback()


# Example #2
def main():
    """Command-line entry point: collect minima and connect them in a graph.

    Reads the positional arguments ``project_directory`` and ``config_file``
    and the optional ``--no-backup`` flag, sets up the project, and logs to
    stdout as well as to ``exploration.log`` inside the project directory.

    An existing graph file is loaded (after being copied to a timestamped
    backup unless ``--no-backup`` is given) and repaired; otherwise an empty
    graph is created. Minima are then added until the graph holds
    ``minima_count`` nodes, and finally ``landscape_exploration`` is run.
    The graph is written to disk after every new minimum and whenever
    ``landscape_exploration`` invokes the supplied callback.
    """
    # Local import: only needed to build the backup file name below.
    from os.path import splitext

    parser = ArgumentParser()
    parser.add_argument("project_directory", nargs=1)
    parser.add_argument("config_file", nargs=1)
    parser.add_argument("--no-backup", default=False, action="store_true")
    args = parser.parse_args()

    # nargs=1 yields one-element lists, hence the [0].
    project_directory = args.project_directory[0]
    config_file = args.config_file[0]
    graph_path, project_config_path = setup_project(config_file,
                                                    project_directory)

    model, minima_count, min_config, lex_config = read_config_file(
        project_config_path)

    # Setup Logger
    root_logger = getLogger()
    root_logger.setLevel(INFO)
    root_logger.addHandler(StreamHandler(sys.stdout))
    root_logger.addHandler(
        FileHandler(join(project_directory, "exploration.log")))

    # === Create/load graph ===
    if isfile(graph_path):
        if not args.no_backup:
            root_logger.info("Copying current 'graph.p' to backup file.")
            # Insert the timestamp in front of the extension only. A plain
            # str.replace(".p", ...) would also rewrite every other ".p" in
            # the path (e.g. a directory named "my.project").
            path_root, path_ext = splitext(graph_path)
            backup_path = (
                f"{path_root}_bak{strftime('%Y%m%d-%H%M')}{path_ext}")
            copyfile(graph_path, backup_path)
        else:
            root_logger.info(
                "Not creating a backup of 'graph.p' because of user request.")
        graph = repair_graph(load_pickle_graph(graph_path), model)
    else:
        graph = MultiGraph()

    # Call this after every optimisation
    def save_callback():
        store_pickle_graph(graph, graph_path)

    # === Ensure the specified number of minima ===
    for _ in pbar(range(len(graph.nodes), minima_count), "Finding minima"):
        minimum_data = find_minimum(model, min_config)
        # Node ids are consecutive integers starting at 1.
        next_node_id = max(graph.nodes, default=0) + 1
        graph.add_node(next_node_id, **move_to(minimum_data, "cpu"))
        save_callback()

    # === Connect minima ===
    landscape_exploration(graph, model, lex_config, callback=save_callback)