Example #1
import os

from ocpmodels.common.registry import registry
from ocpmodels.common.utils import setup_imports


def ocp_trainable(config, checkpoint_dir=None):
    setup_imports()
    # update config for PBT learning rate
    config["optim"].update(lr_initial=config["lr"])
    # trainer defaults are changed to run HPO
    trainer = registry.get_trainer_class(config.get("trainer", "simple"))(
        task=config["task"],
        model=config["model"],
        dataset=config["dataset"],
        optimizer=config["optim"],
        identifier=config["identifier"],
        run_dir=config.get("run_dir", "./"),
        is_debug=config.get("is_debug", False),
        is_vis=config.get("is_vis", False),
        is_hpo=config.get("is_hpo", True),  # hpo
        print_every=config.get("print_every", 10),
        seed=config.get("seed", 0),
        logger=config.get("logger", None),  # hpo
        local_rank=config["local_rank"],
        amp=config.get("amp", False),
        cpu=config.get("cpu", False),
    )
    # add checkpoint here
    if checkpoint_dir:
        checkpoint = os.path.join(checkpoint_dir, "checkpoint")
        trainer.load_pretrained(checkpoint)
    # set learning rate
    for g in trainer.optimizer.param_groups:
        g["lr"] = config["lr"]
    # start training
    trainer.train()
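
This trainable matches Ray Tune's function API: Tune passes checkpoint_dir when restoring a trial, and config["lr"] is the hyperparameter that population based training mutates. A minimal sketch of wiring it into a PBT run (the metric name, mutation range, and base_config are illustrative assumptions, not values from the source):

from ray import tune
from ray.tune.schedulers import PopulationBasedTraining

# Illustrative PBT setup; metric/mode and the mutation range are assumptions.
pbt = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="val_loss",  # assumed name of a metric the trainer reports to Tune
    mode="min",
    perturbation_interval=4,
    hyperparam_mutations={"lr": tune.loguniform(1e-5, 1e-3)},
)

tune.run(
    ocp_trainable,
    config={**base_config, "lr": 1e-4},  # base_config: a full OCP config dict
    scheduler=pbt,
    num_samples=4,
    resources_per_trial={"cpu": 4, "gpu": 1},
)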
Example #2
def main(config):
    if args.distributed:
        distutils.setup(config)

    try:
        setup_imports()
        trainer = registry.get_trainer_class(config.get("trainer", "simple"))(
            task=config["task"],
            model=config["model"],
            dataset=config["dataset"],
            optimizer=config["optim"],
            identifier=config["identifier"],
            run_dir=config.get("run_dir", "./"),
            is_debug=config.get("is_debug", False),
            is_vis=config.get("is_vis", False),
            print_every=config.get("print_every", 10),
            seed=config.get("seed", 0),
            logger=config.get("logger", "tensorboard"),
            local_rank=config["local_rank"],
            amp=config.get("amp", False),
            cpu=config.get("cpu", False),
        )
        if config["checkpoint"] is not None:
            trainer.load_pretrained(config["checkpoint"])

        start_time = time.time()

        if config["mode"] == "train":
            trainer.train()

        elif config["mode"] == "predict":
            assert (
                trainer.test_loader is not None
            ), "Test dataset is required for making predictions"
            assert config["checkpoint"]
            results_file = "predictions"
            trainer.predict(
                trainer.test_loader,
                results_file=results_file,
                disable_tqdm=False,
            )

        elif config["mode"] == "run-relaxations":
            assert isinstance(
                trainer, ForcesTrainer
            ), "Relaxations are only possible for ForcesTrainer"
            assert (
                trainer.relax_dataset is not None
            ), "Relax dataset is required for making predictions"
            assert config["checkpoint"]
            trainer.run_relaxations()

        distutils.synchronize()

        if distutils.is_master():
            print("Total time taken = ", time.time() - start_time)

    finally:
        if args.distributed:
            distutils.cleanup()
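
The keys that main() reads are visible above; a minimal illustrative config (all values are placeholders, and the module-level args from the surrounding script is assumed to have distributed=False):

# Skeleton of the config dict main() expects; every value is a placeholder.
config = {
    "trainer": "forces",
    "task": {"dataset": "trajectory_lmdb"},
    "model": {"name": "schnet"},
    "dataset": [{"src": "data/train"}],
    "optim": {"batch_size": 32, "lr_initial": 1e-4},
    "identifier": "example-run",
    "local_rank": 0,
    "checkpoint": None,  # a path is required for "predict" / "run-relaxations"
    "mode": "train",     # "train" | "predict" | "run-relaxations"
}
main(config)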
Example #3
    def __init__(self,
                 config_yml,
                 checkpoint=None,
                 cutoff=6,
                 max_neighbors=50):
        """
        OCP-ASE Calculator

        Args:
            config_yml (str):
                Path to yaml config.
            checkpoint (str):
                Path to trained checkpoint.
            cutoff (int):
                Cutoff radius to be used for data preprocessing.
            max_neighbors (int):
                Maximum amount of neighbors to store for a given atom.
        """
        setup_imports()
        setup_logging()
        Calculator.__init__(self)

        with open(config_yml, "r") as f:
            config = yaml.safe_load(f)
        if "includes" in config:
            for include in config["includes"]:
                with open(include, "r") as f:
                    config.update(yaml.safe_load(f))

        # Save config so obj can be transported over network (pkl)
        self.config = copy.deepcopy(config)
        self.config["checkpoint"] = checkpoint

        self.trainer = registry.get_trainer_class(
            config.get("trainer", "simple"))(
                task=config["task"],
                model=config["model"],
                dataset=config["dataset"],
                optimizer=config["optim"],
                identifier="",
                slurm=config.get("slurm", {}),
                local_rank=config.get("local_rank", 0),
                is_debug=config.get("is_debug", True),
                cpu=True,
            )

        if checkpoint is not None:
            self.load_checkpoint(checkpoint)

        self.a2g = AtomsToGraphs(
            max_neigh=max_neighbors,
            radius=cutoff,
            r_energy=False,
            r_forces=False,
            r_distances=False,
        )
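
Once constructed, the object behaves like any other ASE calculator. A usage sketch, assuming the enclosing class is OCPCalculator and using placeholder paths and a toy structure:

from ase.build import add_adsorbate, fcc111

# Toy Cu(111) slab with an O adsorbate; structure and paths are illustrative.
slab = fcc111("Cu", size=(2, 2, 3), vacuum=8.0)
add_adsorbate(slab, "O", height=1.2, position="fcc")

calc = OCPCalculator(config_yml="configs/s2ef/schnet.yml",
                     checkpoint="checkpoints/schnet_all.pt")
slab.calc = calc
print(slab.get_potential_energy())
print(slab.get_forces())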
Example #4
    def _setup(self, config):
        self.trainer = registry.get_trainer_class("simple")(
            task=config["task"],
            model=config["model"],
            dataset=config["dataset"],
            optimizer=config["optim"],
            identifier="",
            is_debug=True,
            seed=0,
            logger=None,
        )

        print("Device = {dev}".format(dev=self.trainer.device))
Example #5
import time

from ocpmodels.common import distutils
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import setup_imports


def main(config):
    setup_imports()
    trainer = registry.get_trainer_class(config.get("trainer", "simple"))(
        task=config["task"],
        model=config["model"],
        dataset=config["dataset"],
        optimizer=config["optim"],
        identifier=config["identifier"],
        run_dir=config.get("run_dir", "./"),
        is_debug=config.get("is_debug", False),
        is_vis=config.get("is_vis", False),
        print_every=config.get("print_every", 10),
        seed=config.get("seed", 0),
        logger=config.get("logger", "tensorboard"),
        local_rank=config["local_rank"],
        amp=config.get("amp", False),
    )
    start_time = time.time()
    trainer.train()
    distutils.synchronize()
    print("Time = ", time.time() - start_time)
Example #6
File: main.py Project: wood-b/ocp
    def __call__(self, config):
        setup_logging()
        self.config = copy.deepcopy(config)

        if args.distributed:
            distutils.setup(config)

        try:
            setup_imports()
            self.trainer = registry.get_trainer_class(
                config.get("trainer", "simple"))(
                    task=config["task"],
                    model=config["model"],
                    dataset=config["dataset"],
                    optimizer=config["optim"],
                    identifier=config["identifier"],
                    timestamp_id=config.get("timestamp_id", None),
                    run_dir=config.get("run_dir", "./"),
                    is_debug=config.get("is_debug", False),
                    is_vis=config.get("is_vis", False),
                    print_every=config.get("print_every", 10),
                    seed=config.get("seed", 0),
                    logger=config.get("logger", "tensorboard"),
                    local_rank=config["local_rank"],
                    amp=config.get("amp", False),
                    cpu=config.get("cpu", False),
                    slurm=config.get("slurm", {}),
                )
            self.task = registry.get_task_class(config["mode"])(self.config)
            self.task.setup(self.trainer)
            start_time = time.time()
            self.task.run()
            distutils.synchronize()
            if distutils.is_master():
                logging.info(f"Total time taken: {time.time() - start_time}")
        finally:
            if args.distributed:
                distutils.cleanup()
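
Both get_trainer_class and get_task_class look up classes that were registered under a string name via decorators. A stripped-down sketch of that registry pattern (not the actual ocpmodels implementation, which covers models, tasks, loggers, and more):

# Minimal registry sketch; names and structure are illustrative.
class Registry:
    mapping = {"trainers": {}}

    @classmethod
    def register_trainer(cls, name):
        def wrap(trainer_cls):
            cls.mapping["trainers"][name] = trainer_cls
            return trainer_cls
        return wrap

    @classmethod
    def get_trainer_class(cls, name):
        return cls.mapping["trainers"].get(name)

registry = Registry()

@registry.register_trainer("simple")
class SimpleTrainer:
    pass

assert registry.get_trainer_class("simple") is SimpleTrainer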
Example #7
            logging.info("Exiting script")
            sys.exit()
    else:
        initialize_scale_file(scale_file)

    AutomaticFit.set2fitmode()

    trainer = registry.get_trainer_class(config.get("trainer", "simple"))(
        task=config["task"],
        model=config["model"],
        dataset=config["dataset"],
        optimizer=config["optim"],
        identifier=config["identifier"],
        run_dir=config.get("run_dir", "./"),
        is_debug=config.get("is_debug", False),
        is_vis=config.get("is_vis", False),
        print_every=config.get("print_every", 10),
        seed=config.get("seed", 0),
        logger=config.get("logger", "tensorboard"),
        local_rank=config["local_rank"],
        amp=config.get("amp", False),
        cpu=config.get("cpu", False),
        slurm=config.get("slurm", {}),
    )

    # Fitting loop
    logging.info("Start fitting")

    if not AutomaticFit.fitting_completed():
        with torch.no_grad():
            trainer.model.eval()
Example #8
def oc20_initialize(model_name, gpu=True):
    """
    Initialize GNNP of OC20 (i.e. S2EF).
    Args:
        model_name (str): name of model for GNNP. One can use the followings,
            - "DimeNet++"
            - "GemNet-dT"
            - "CGCNN"
            - "SchNet"
            - "SpinConv"
        gpu (bool): using GPU, if possible.
    Returns:
        cutoff: cutoff radius.
    """

    setup_imports()
    setup_logging()

    # Check model_name
    log_file = open("log.oc20", "w")
    log_file.write("\n")
    log_file.write("model_name = " + model_name + "\n")

    if model_name is not None:
        model_name = model_name.lower()

    if model_name == "DimeNet++".lower():
        config_yml = "dimenetpp.yml"
        checkpoint = "dimenetpp_all.pt"

    elif model_name == "GemNet-dT".lower():
        config_yml = "gemnet.yml"
        checkpoint = "gemnet_t_direct_h512_all.pt"

    elif model_name == "CGCNN".lower():
        config_yml = "cgcnn.yml"
        checkpoint = "cgcnn_all.pt"

    elif model_name == "SchNet".lower():
        config_yml = "schnet.yml"
        checkpoint = "schnet_all_large.pt"

    elif model_name == "SpinConv".lower():
        config_yml = "spinconv.yml"
        checkpoint = "spinconv_force_centric_all.pt"

    else:
        raise ValueError("Incorrect model_name: {}".format(model_name))

    basePath = os.path.dirname(os.path.abspath(__file__))
    config_dir = os.path.normpath(os.path.join(basePath, "oc20_configs"))
    checkpt_dir = os.path.normpath(os.path.join(basePath, "oc20_checkpt"))
    config_yml = os.path.normpath(os.path.join(config_dir, config_yml))
    checkpoint = os.path.normpath(os.path.join(checkpt_dir, checkpoint))

    log_file.write("config_yml = " + config_yml + "\n")
    log_file.write("checkpoint = " + checkpoint + "\n")

    # Check gpu
    gpu_ = (gpu and torch.cuda.is_available())

    log_file.write("gpu (in)   = " + str(gpu) + "\n")
    log_file.write("gpu (eff)  = " + str(gpu_) + "\n")

    # Load configuration
    config = yaml.safe_load(open(config_yml, "r"))

    # Check max_neigh and cutoff
    max_neigh = config["model"].get("max_neighbors", 50)
    cutoff = config["model"].get("cutoff", 6.0)

    log_file.write("max_neigh  = " + str(max_neigh) + "\n")
    log_file.write("cutoff     = " + str(cutoff) + "\n")

    assert max_neigh > 0
    assert cutoff > 0.0

    # Compute the edge indices on the fly
    config["model"]["otf_graph"] = True

    # Modify path of scale_file for GemNet-dT
    scale_file = config["model"].get("scale_file", None)

    if scale_file is not None:
        scale_file = os.path.normpath(os.path.join(config_dir, scale_file))
        config["model"]["scale_file"] = scale_file

    log_file.write("\nconfig:\n")
    log_file.write(pprint.pformat(config) + "\n")
    log_file.write("\n")
    log_file.close()

    # Create the trainer for the pre-trained model
    global myTrainer

    myTrainer = registry.get_trainer_class(config.get("trainer", "forces"))(
        task=config["task"],
        model=config["model"],
        dataset=None,
        normalizer=config["normalizer"],
        optimizer=config["optim"],
        identifier="",
        slurm=config.get("slurm", {}),
        local_rank=config.get("local_rank", 0),
        is_debug=config.get("is_debug", True),
        cpu=not gpu_)

    # Load checkpoint
    myTrainer.load_checkpoint(checkpoint)

    # ASE Atoms object; empty until a structure is assigned
    global myAtoms

    myAtoms = None

    # Converter: Atoms -> graphs (edges computed on the fly)
    global myA2G

    myA2G = AtomsToGraphs(max_neigh=max_neigh,
                          radius=cutoff,
                          r_energy=False,
                          r_forces=False,
                          r_distances=False,
                          r_edges=False,
                          r_fixed=False)

    return cutoff
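
After initialization, the module-level trainer and converter can serve predictions for an ASE Atoms object. A sketch of a helper (the function itself is hypothetical; the Batch handling mirrors how torch_geometric batches a single structure):

from torch_geometric.data import Batch

def oc20_get_energy_and_forces(atoms):
    # Hypothetical helper: predict energy and forces for one ASE Atoms object.
    data = myA2G.convert(atoms)           # Atoms -> torch_geometric Data
    batch = Batch.from_data_list([data])  # batch with a single structure
    preds = myTrainer.predict(batch, per_image=False, disable_tqdm=True)
    return preds["energy"], preds["forces"]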