Esempio n. 1
0
def main(config):
    """Build the trainer described by ``config`` and run the requested mode.

    ``config["mode"]`` selects what happens after the trainer is built:
    ``"train"`` calls ``trainer.train()``, ``"predict"`` runs
    ``trainer.predict`` on the test loader and ``"run-relaxations"`` calls
    ``trainer.run_relaxations()``. Any other value runs no task. When the
    module-level ``args.distributed`` flag is set, ``distutils.setup`` is
    called first and ``distutils.cleanup`` is called on the way out, even if
    the run raises.

    Args:
        config: Mapping that must contain "task", "model", "dataset",
            "optim", "identifier", "local_rank", "checkpoint" and "mode";
            the other trainer options fall back to the defaults below.
    """
    if args.distributed:
        distutils.setup(config)

    try:
        setup_imports()

        # Resolve the class first, then evaluate its constructor arguments.
        trainer_cls = registry.get_trainer_class(
            config.get("trainer", "simple")
        )
        trainer_kwargs = {
            "task": config["task"],
            "model": config["model"],
            "dataset": config["dataset"],
            "optimizer": config["optim"],
            "identifier": config["identifier"],
            "run_dir": config.get("run_dir", "./"),
            "is_debug": config.get("is_debug", False),
            "is_vis": config.get("is_vis", False),
            "print_every": config.get("print_every", 10),
            "seed": config.get("seed", 0),
            "logger": config.get("logger", "tensorboard"),
            "local_rank": config["local_rank"],
            "amp": config.get("amp", False),
            "cpu": config.get("cpu", False),
        }
        trainer = trainer_cls(**trainer_kwargs)

        checkpoint = config["checkpoint"]
        if checkpoint is not None:
            trainer.load_pretrained(checkpoint)

        start_time = time.time()
        mode = config["mode"]

        if mode == "train":
            trainer.train()
        elif mode == "predict":
            test_loader = trainer.test_loader
            assert (
                test_loader is not None
            ), "Test dataset is required for making predictions"
            assert config["checkpoint"]
            trainer.predict(
                test_loader,
                results_file="predictions",
                disable_tqdm=False,
            )
        elif mode == "run-relaxations":
            assert isinstance(
                trainer, ForcesTrainer
            ), "Relaxations are only possible for ForcesTrainer"
            assert (
                trainer.relax_dataset is not None
            ), "Relax dataset is required for making predictions"
            assert config["checkpoint"]
            trainer.run_relaxations()

        distutils.synchronize()

        # Only the master rank reports the elapsed wall-clock time.
        if distutils.is_master():
            print("Total time taken = ", time.time() - start_time)
    finally:
        if args.distributed:
            distutils.cleanup()
Esempio n. 2
0
    def save_results(self, predictions, results_file, keys):
        """Save this rank's predictions, then merge all ranks' files on master.

        Every rank writes ``predictions`` to
        ``{self.name}_{results_file}_{rank}.npz`` inside
        ``self.config["cmd"]["results_dir"]``. After ``distutils.synchronize()``
        the master rank loads each per-rank file, deletes it, keeps the first
        occurrence of every id and writes the merged arrays to
        ``{self.name}_{results_file}.npz``.

        Args:
            predictions: Mapping with an "id" entry plus one entry for each
                name in ``keys``.
            results_file: Base name of the output file. ``None`` disables
                saving and the method returns immediately.
            keys: Names of the ``predictions`` entries to store. "forces" and
                "chunk_idx" get special handling when merging (see below).
        """
        if results_file is None:
            return

        results_dir = self.config["cmd"]["results_dir"]
        file_stem = f"{self.name}_{results_file}"

        results_file_path = os.path.join(
            results_dir, f"{file_stem}_{distutils.get_rank()}.npz"
        )
        np.savez_compressed(
            results_file_path,
            ids=predictions["id"],
            **{key: predictions[key] for key in keys},
        )

        # NOTE(review): synchronize() is presumably a barrier so that every
        # rank has written its file before master reads them - confirm.
        distutils.synchronize()
        if not distutils.is_master():
            return

        gather_results = defaultdict(list)
        full_path = os.path.join(results_dir, f"{file_stem}.npz")

        for rank in range(distutils.get_world_size()):
            rank_path = os.path.join(results_dir, f"{file_stem}_{rank}.npz")
            # Close the archive before deleting it: the object returned by
            # np.load keeps the underlying file open until it is closed.
            with np.load(rank_path, allow_pickle=True) as rank_results:
                gather_results["ids"].extend(rank_results["ids"])
                for key in keys:
                    gather_results[key].extend(rank_results[key])
            os.remove(rank_path)

        # Because of how distributed sampler works, some system ids
        # might be repeated to make no. of samples even across GPUs.
        _, idx = np.unique(gather_results["ids"], return_index=True)
        gather_results["ids"] = np.array(gather_results["ids"])[idx]
        for k in keys:
            if k == "forces":
                # Per-system force arrays may differ in length, so do not
                # build an intermediate np.array from them (ragged input
                # raises on recent NumPy); concatenate the selected
                # per-system arrays directly instead.
                per_system = gather_results[k]
                gather_results[k] = np.concatenate(
                    [per_system[i] for i in idx]
                )
            elif k == "chunk_idx":
                # Running sum of the selected entries, without the final
                # total.
                gather_results[k] = np.cumsum(
                    np.array(gather_results[k])[idx]
                )[:-1]
            else:
                gather_results[k] = np.array(gather_results[k])[idx]

        logging.info(f"Writing results to {full_path}")
        np.savez_compressed(full_path, **gather_results)
Esempio n. 3
0
def main(config):
    """Build the configured trainer, train it and print the elapsed time.

    The trainer class is looked up in ``registry`` under
    ``config["trainer"]`` (default ``"simple"``). After ``trainer.train()``
    returns, ``distutils.synchronize()`` is called and the wall-clock time
    since just before training is printed.

    Args:
        config: Mapping that must contain "task", "model", "dataset",
            "optim", "identifier" and "local_rank"; the other trainer
            options fall back to the defaults below.
    """
    import time

    setup_imports()

    # Resolve the class first, then evaluate its constructor arguments.
    trainer_cls = registry.get_trainer_class(config.get("trainer", "simple"))
    trainer_kwargs = {
        "task": config["task"],
        "model": config["model"],
        "dataset": config["dataset"],
        "optimizer": config["optim"],
        "identifier": config["identifier"],
        "run_dir": config.get("run_dir", "./"),
        "is_debug": config.get("is_debug", False),
        "is_vis": config.get("is_vis", False),
        "print_every": config.get("print_every", 10),
        "seed": config.get("seed", 0),
        "logger": config.get("logger", "tensorboard"),
        "local_rank": config["local_rank"],
        "amp": config.get("amp", False),
    }
    trainer = trainer_cls(**trainer_kwargs)

    began = time.time()
    trainer.train()
    distutils.synchronize()
    print('Time = ', time.time() - began)
Esempio n. 4
0
File: main.py Progetto: wood-b/ocp
    def __call__(self, config):
        """Build the trainer and task named in ``config`` and run the task.

        A deep copy of ``config`` is stored on ``self.config`` and handed to
        the task; the trainer itself is built from the original ``config``.
        When the module-level ``args.distributed`` flag is set,
        ``distutils.setup`` is called before the run and ``distutils.cleanup``
        afterwards, even if the run raises. The master rank logs the time
        spent in ``self.task.run()`` plus the following synchronize call.

        Args:
            config: Mapping that must contain "task", "model", "dataset",
                "optim", "identifier", "local_rank" and "mode"; the other
                trainer options fall back to the defaults below.
        """
        setup_logging()
        self.config = copy.deepcopy(config)

        if args.distributed:
            distutils.setup(config)

        try:
            setup_imports()

            # Resolve the class first, then evaluate its constructor
            # arguments.
            trainer_cls = registry.get_trainer_class(
                config.get("trainer", "simple")
            )
            trainer_kwargs = {
                "task": config["task"],
                "model": config["model"],
                "dataset": config["dataset"],
                "optimizer": config["optim"],
                "identifier": config["identifier"],
                "timestamp_id": config.get("timestamp_id", None),
                "run_dir": config.get("run_dir", "./"),
                "is_debug": config.get("is_debug", False),
                "is_vis": config.get("is_vis", False),
                "print_every": config.get("print_every", 10),
                "seed": config.get("seed", 0),
                "logger": config.get("logger", "tensorboard"),
                "local_rank": config["local_rank"],
                "amp": config.get("amp", False),
                "cpu": config.get("cpu", False),
                "slurm": config.get("slurm", {}),
            }
            self.trainer = trainer_cls(**trainer_kwargs)

            task_cls = registry.get_task_class(config["mode"])
            self.task = task_cls(self.config)
            self.task.setup(self.trainer)

            start_time = time.time()
            self.task.run()
            distutils.synchronize()

            if distutils.is_master():
                logging.info(f"Total time taken: {time.time() - start_time}")
        finally:
            if args.distributed:
                distutils.cleanup()
Esempio n. 5
0
    def run_relaxations(self, split="val", epoch=None):
        """Run ML relaxations over ``self.relax_loader`` and report results.

        Each batch is relaxed with ``ml_relax``. Depending on the config and
        the dataset, the method then:

        * writes the relaxed positions to ``relaxed_positions.npz`` in
          ``self.config["cmd"]["results_dir"]`` when
          ``config["task"]["write_pos"]`` is truthy (each rank writes its own
          file, master merges them and removes the per-rank files);
        * evaluates "is2rs" metrics against the relaxed reference when the
          first dataset sample has both ``pos_relaxed`` and ``y_relaxed``,
          sums them across ranks with ``distutils.all_reduce``, optionally
          logs them and prints them on master.

        Args:
            split: Ignored as an input - the value is overwritten below with
                "val" or "test" based on the first sample of
                ``self.relax_dataset``.
            epoch: When not ``None`` (and ``self.logger`` is set), used to
                compute the step at which metrics are logged.
        """
        print("### Running ML-relaxations")
        self.model.eval()

        evaluator, metrics = Evaluator(task="is2rs"), {}

        # The caller-supplied `split` is replaced here: "val" only when the
        # first sample carries both relaxed reference attributes.
        if hasattr(self.relax_dataset[0], "pos_relaxed") and hasattr(
                self.relax_dataset[0], "y_relaxed"):
            split = "val"
        else:
            split = "test"

        ids = []
        relaxed_positions = []
        # NOTE: the enumerate index `i` is not used in the loop body; the `i`
        # in the list comprehension below is a separate, comprehension-local
        # name.
        for i, batch in tqdm(enumerate(self.relax_loader),
                             total=len(self.relax_loader)):
            # `model=self` passes this object itself (not `self.model`) to
            # ml_relax. "relax_opt" is a required config key; steps and fmax
            # default to 200 and 0.0.
            relaxed_batch = ml_relax(
                batch=batch,
                model=self,
                steps=self.config["task"].get("relaxation_steps", 200),
                fmax=self.config["task"].get("relaxation_fmax", 0.0),
                relax_opt=self.config["task"]["relax_opt"],
                device=self.device,
                transform=None,
            )

            if self.config["task"].get("write_pos", False):
                # Collect system ids as strings and split the flat position
                # tensor into one chunk per system using `natoms`.
                systemids = [str(i) for i in relaxed_batch.sid.tolist()]
                natoms = relaxed_batch.natoms.tolist()
                positions = torch.split(relaxed_batch.pos, natoms)
                batch_relaxed_positions = [pos.tolist() for pos in positions]

                relaxed_positions += batch_relaxed_positions
                ids += systemids

            if split == "val":
                # Keep entries whose `fixed` value is 0 (presumably the
                # unconstrained atoms - confirm against the dataset).
                mask = relaxed_batch.fixed == 0
                s_idx = 0
                natoms_free = []
                # Count the mask-selected entries of each system by walking
                # consecutive `natoms`-sized slices. This rebinds the name
                # `natoms` that the write_pos branch above may have set.
                for natoms in relaxed_batch.natoms:
                    natoms_free.append(
                        torch.sum(mask[s_idx:s_idx + natoms]).item())
                    s_idx += natoms

                # Reference values: relaxed energy/positions from the batch.
                target = {
                    "energy": relaxed_batch.y_relaxed,
                    "positions": relaxed_batch.pos_relaxed[mask],
                    "cell": relaxed_batch.cell,
                    "pbc": torch.tensor([True, True, True]),
                    "natoms": torch.LongTensor(natoms_free),
                }

                # Values taken from the batch returned by ml_relax.
                prediction = {
                    "energy": relaxed_batch.y,
                    "positions": relaxed_batch.pos[mask],
                    "cell": relaxed_batch.cell,
                    "pbc": torch.tensor([True, True, True]),
                    "natoms": torch.LongTensor(natoms_free),
                }

                # Metrics are accumulated across batches by passing the
                # previous `metrics` back in.
                metrics = evaluator.eval(prediction, target, metrics)

        if self.config["task"].get("write_pos", False):
            # Each rank first writes its own file ...
            rank = distutils.get_rank()
            pos_filename = os.path.join(self.config["cmd"]["results_dir"],
                                        f"relaxed_pos_{rank}.npz")
            np.savez_compressed(
                pos_filename,
                ids=ids,
                pos=np.array(relaxed_positions, dtype=object),
            )

            # ... then, after synchronizing, master merges the per-rank files
            # into one and deletes them.
            distutils.synchronize()
            if distutils.is_master():
                gather_results = defaultdict(list)
                full_path = os.path.join(
                    self.config["cmd"]["results_dir"],
                    "relaxed_positions.npz",
                )

                for i in range(distutils.get_world_size()):
                    rank_path = os.path.join(
                        self.config["cmd"]["results_dir"],
                        f"relaxed_pos_{i}.npz",
                    )
                    rank_results = np.load(rank_path, allow_pickle=True)
                    gather_results["ids"].extend(rank_results["ids"])
                    gather_results["pos"].extend(rank_results["pos"])
                    os.remove(rank_path)

                # Because of how distributed sampler works, some system ids
                # might be repeated to make no. of samples even across GPUs.
                _, idx = np.unique(gather_results["ids"], return_index=True)
                gather_results["ids"] = np.array(gather_results["ids"])[idx]
                gather_results["pos"] = np.array(gather_results["pos"],
                                                 dtype=object)[idx]

                print(f"Writing results to {full_path}")
                np.savez_compressed(full_path, **gather_results)

        if split == "val":
            # Combine per-rank metric sums: "total" and "numel" are reduced
            # without averaging, then the metric is recomputed as their ratio.
            aggregated_metrics = {}
            for k in metrics:
                aggregated_metrics[k] = {
                    "total":
                    distutils.all_reduce(metrics[k]["total"],
                                         average=False,
                                         device=self.device),
                    "numel":
                    distutils.all_reduce(metrics[k]["numel"],
                                         average=False,
                                         device=self.device),
                }
                aggregated_metrics[k]["metric"] = (
                    aggregated_metrics[k]["total"] /
                    aggregated_metrics[k]["numel"])
            metrics = aggregated_metrics

            # Make plots.
            log_dict = {k: metrics[k]["metric"] for k in metrics}
            # Logging needs both a logger and an epoch to compute the step.
            if self.logger is not None and epoch is not None:
                self.logger.log(
                    log_dict,
                    step=(epoch + 1) * len(self.train_loader),
                    split=split,
                )

            if distutils.is_master():
                print(metrics)