Example #1
    def validate(self, split="val", epoch=None, disable_tqdm=False):
        if distutils.is_master():
            print("### Evaluating on {}.".format(split))
        if self.is_hpo:
            disable_tqdm = True

        self.model.eval()
        evaluator, metrics = Evaluator(task=self.name), {}
        rank = distutils.get_rank()

        loader = self.val_loader if split == "val" else self.test_loader

        for i, batch in tqdm(
                enumerate(loader),
                total=len(loader),
                position=rank,
                desc="device {}".format(rank),
                disable=disable_tqdm,
        ):
            # Forward.
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)
            loss = self._compute_loss(out, batch)

            # Compute metrics.
            metrics = self._compute_metrics(out, batch, evaluator, metrics)
            metrics = evaluator.update("loss", loss.item(), metrics)

        aggregated_metrics = {}
        for k in metrics:
            aggregated_metrics[k] = {
                "total":
                distutils.all_reduce(metrics[k]["total"],
                                     average=False,
                                     device=self.device),
                "numel":
                distutils.all_reduce(metrics[k]["numel"],
                                     average=False,
                                     device=self.device),
            }
            aggregated_metrics[k]["metric"] = (aggregated_metrics[k]["total"] /
                                               aggregated_metrics[k]["numel"])
        metrics = aggregated_metrics

        log_dict = {k: metrics[k]["metric"] for k in metrics}
        log_dict.update({"epoch": epoch + 1})
        if distutils.is_master():
            log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()]
            print(", ".join(log_str))

        # Make plots.
        if self.logger is not None and epoch is not None:
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        return metrics
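
Note: the aggregation at the end of validate() keeps a running "total" and "numel" per metric, sums both across ranks with all_reduce(average=False), and recomputes the mean as total / numel. A minimal single-process sketch of that arithmetic, with hypothetical values:

# Hypothetical per-rank running sums for one metric.
rank_metrics = [
    {"energy_mae": {"total": 12.0, "numel": 8}},   # rank 0
    {"energy_mae": {"total": 15.0, "numel": 10}},  # rank 1
]

aggregated = {}
for k in rank_metrics[0]:
    # all_reduce(average=False) amounts to summing each field across ranks.
    total = sum(m[k]["total"] for m in rank_metrics)
    numel = sum(m[k]["numel"] for m in rank_metrics)
    aggregated[k] = {"total": total, "numel": numel, "metric": total / numel}

print(aggregated)  # {'energy_mae': {'total': 27.0, 'numel': 18, 'metric': 1.5}}
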
Example #2
    def load_model(self):
        # Build model
        if distutils.is_master():
            logging.info(f"Loading model: {self.config['model']}")

        # TODO(abhshkdz): Eventually move towards computing features on-the-fly
        # and remove dependence from `.edge_attr`.
        bond_feat_dim = None
        if self.config["task"]["dataset"] in [
            "trajectory_lmdb",
            "single_point_lmdb",
        ]:
            bond_feat_dim = self.config["model_attributes"].get(
                "num_gaussians", 50
            )
        else:
            raise NotImplementedError

        loader = self.train_loader or self.val_loader or self.test_loader
        self.model = registry.get_model_class(self.config["model"])(
            loader.dataset[0].x.shape[-1]
            if loader
            and hasattr(loader.dataset[0], "x")
            and loader.dataset[0].x is not None
            else None,
            bond_feat_dim,
            self.num_targets,
            **self.config["model_attributes"],
        ).to(self.device)

        if distutils.is_master():
            logging.info(
                f"Loaded {self.model.__class__.__name__} with "
                f"{self.model.num_params} parameters."
            )

        if self.logger is not None:
            self.logger.watch(self.model)

        self.model = OCPDataParallel(
            self.model,
            output_device=self.device,
            num_gpus=1 if not self.cpu else 0,
        )
        if distutils.initialized():
            self.model = DistributedDataParallel(
                self.model, device_ids=[self.device]
            )
Example #3
 def load_logger(self):
     self.logger = None
     if not self.is_debug and distutils.is_master():
         assert (self.config["logger"]
                 is not None), "Specify logger in config"
         self.logger = registry.get_logger_class(self.config["logger"])(
             self.config)
Example #4
    def predict(self, loader, results_file=None, disable_tqdm=False):
        if distutils.is_master() and not disable_tqdm:
            print("### Predicting on test.")
        assert isinstance(loader, torch.utils.data.dataloader.DataLoader)
        rank = distutils.get_rank()

        self.model.eval()
        if self.normalizers is not None and "target" in self.normalizers:
            self.normalizers["target"].to(self.device)
        predictions = {"id": [], "energy": []}

        for i, batch in tqdm(
                enumerate(loader),
                total=len(loader),
                position=rank,
                desc="device {}".format(rank),
                disable=disable_tqdm,
        ):
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)

            if self.normalizers is not None and "target" in self.normalizers:
                out["energy"] = self.normalizers["target"].denorm(
                    out["energy"])
            predictions["id"].extend([str(i) for i in batch[0].sid.tolist()])
            predictions["energy"].extend(out["energy"].tolist())

        self.save_results(predictions, results_file, keys=["energy"])
        return predictions
Example #5
def main(config):
    if args.distributed:
        distutils.setup(config)

    try:
        setup_imports()
        trainer = registry.get_trainer_class(config.get("trainer", "simple"))(
            task=config["task"],
            model=config["model"],
            dataset=config["dataset"],
            optimizer=config["optim"],
            identifier=config["identifier"],
            run_dir=config.get("run_dir", "./"),
            is_debug=config.get("is_debug", False),
            is_vis=config.get("is_vis", False),
            print_every=config.get("print_every", 10),
            seed=config.get("seed", 0),
            logger=config.get("logger", "tensorboard"),
            local_rank=config["local_rank"],
            amp=config.get("amp", False),
            cpu=config.get("cpu", False),
        )
        if config["checkpoint"] is not None:
            trainer.load_pretrained(config["checkpoint"])

        start_time = time.time()

        if config["mode"] == "train":
            trainer.train()

        elif config["mode"] == "predict":
            assert (
                trainer.test_loader
                is not None), "Test dataset is required for making predictions"
            assert config["checkpoint"]
            results_file = "predictions"
            trainer.predict(
                trainer.test_loader,
                results_file=results_file,
                disable_tqdm=False,
            )

        elif config["mode"] == "run-relaxations":
            assert isinstance(
                trainer, ForcesTrainer
            ), "Relaxations are only possible for ForcesTrainer"
            assert (trainer.relax_dataset is not None
                    ), "Relax dataset is required for making predictions"
            assert config["checkpoint"]
            trainer.run_relaxations()

        distutils.synchronize()

        if distutils.is_master():
            print("Total time taken = ", time.time() - start_time)

    finally:
        if args.distributed:
            distutils.cleanup()
Example #6
    def predict(self,
                loader,
                per_image=True,
                results_file=None,
                disable_tqdm=False):
        if distutils.is_master() and not disable_tqdm:
            logging.info("Predicting on test.")
        assert isinstance(
            loader,
            (
                torch.utils.data.dataloader.DataLoader,
                torch_geometric.data.Batch,
            ),
        )
        rank = distutils.get_rank()

        if isinstance(loader, torch_geometric.data.Batch):
            loader = [[loader]]

        self.model.eval()
        if self.ema:
            self.ema.store()
            self.ema.copy_to()

        if self.normalizers is not None and "target" in self.normalizers:
            self.normalizers["target"].to(self.device)
        predictions = {"id": [], "energy": []}

        for i, batch in tqdm(
                enumerate(loader),
                total=len(loader),
                position=rank,
                desc="device {}".format(rank),
                disable=disable_tqdm,
        ):
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)

            if self.normalizers is not None and "target" in self.normalizers:
                out["energy"] = self.normalizers["target"].denorm(
                    out["energy"])

            if per_image:
                predictions["id"].extend(
                    [str(i) for i in batch[0].sid.tolist()])
                predictions["energy"].extend(out["energy"].tolist())
            else:
                predictions["energy"] = out["energy"].detach()
                return predictions

        self.save_results(predictions, results_file, keys=["energy"])

        if self.ema:
            self.ema.restore()

        return predictions
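
Note: the ema.store()/copy_to()/restore() calls swap exponential-moving-average weights in for inference and then put the live training weights back. A hedged sketch of that pattern, assuming an EMA helper with a torch_ema-style API (the project may ship its own equivalent):

import torch
from torch_ema import ExponentialMovingAverage  # assumption: a torch_ema-style helper

model = torch.nn.Linear(4, 1)
ema = ExponentialMovingAverage(model.parameters(), decay=0.999)
# ... during training, ema.update() would be called after each optimizer step ...

ema.store()    # stash the current training weights
ema.copy_to()  # load the averaged weights into the model for evaluation
with torch.no_grad():
    _ = model(torch.randn(2, 4))
ema.restore()  # restore the stashed training weights
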
Example #7
 def save(
     self,
     metrics=None,
     checkpoint_file="checkpoint.pt",
     training_state=True,
 ):
     if not self.is_debug and distutils.is_master():
         if training_state:
             save_checkpoint(
                 {
                     "epoch": self.epoch,
                     "step": self.step,
                     "state_dict": self.model.state_dict(),
                     "optimizer": self.optimizer.state_dict(),
                     "scheduler": self.scheduler.scheduler.state_dict()
                     if self.scheduler.scheduler_type != "Null"
                     else None,
                     "normalizers": {
                         key: value.state_dict()
                         for key, value in self.normalizers.items()
                     },
                     "config": self.config,
                     "val_metrics": metrics,
                     "ema": self.ema.state_dict() if self.ema else None,
                     "amp": self.scaler.state_dict()
                     if self.scaler
                     else None,
                 },
                 checkpoint_dir=self.config["cmd"]["checkpoint_dir"],
                 checkpoint_file=checkpoint_file,
             )
         else:
             if self.ema:
                 self.ema.store()
                 self.ema.copy_to()
             save_checkpoint(
                 {
                     "state_dict": self.model.state_dict(),
                     "normalizers": {
                         key: value.state_dict()
                         for key, value in self.normalizers.items()
                     },
                     "config": self.config,
                     "val_metrics": metrics,
                     "amp": self.scaler.state_dict()
                     if self.scaler
                     else None,
                 },
                 checkpoint_dir=self.config["cmd"]["checkpoint_dir"],
                 checkpoint_file=checkpoint_file,
             )
             if self.ema:
                 self.ema.restore()
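
Note: the checkpoint written by save() is a plain dict of state_dicts. Assuming save_checkpoint() serializes it with torch.save, it can be inspected and restored roughly like this (paths and variable names are illustrative):

import torch

ckpt = torch.load("checkpoints/checkpoint.pt", map_location="cpu")
print(sorted(ckpt.keys()))  # e.g. ['amp', 'config', 'epoch', 'state_dict', ...]

# Restoring into objects built by the trainer (illustrative names):
# model.load_state_dict(ckpt["state_dict"])
# optimizer.load_state_dict(ckpt["optimizer"])  # only present when training_state=True
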
Example #8
    def save_results(self, predictions, results_file, keys):
        if results_file is None:
            return

        results_file_path = os.path.join(
            self.config["cmd"]["results_dir"],
            f"{self.name}_{results_file}_{distutils.get_rank()}.npz",
        )
        np.savez_compressed(
            results_file_path,
            ids=predictions["id"],
            **{key: predictions[key] for key in keys},
        )

        distutils.synchronize()
        if distutils.is_master():
            gather_results = defaultdict(list)
            full_path = os.path.join(
                self.config["cmd"]["results_dir"],
                f"{self.name}_{results_file}.npz",
            )

            for i in range(distutils.get_world_size()):
                rank_path = os.path.join(
                    self.config["cmd"]["results_dir"],
                    f"{self.name}_{results_file}_{i}.npz",
                )
                rank_results = np.load(rank_path, allow_pickle=True)
                gather_results["ids"].extend(rank_results["ids"])
                for key in keys:
                    gather_results[key].extend(rank_results[key])
                os.remove(rank_path)

            # Because of how distributed sampler works, some system ids
            # might be repeated to make no. of samples even across GPUs.
            _, idx = np.unique(gather_results["ids"], return_index=True)
            gather_results["ids"] = np.array(gather_results["ids"])[idx]
            for k in keys:
                if k == "forces":
                    gather_results[k] = np.concatenate(
                        np.array(gather_results[k])[idx]
                    )
                elif k == "chunk_idx":
                    gather_results[k] = np.cumsum(
                        np.array(gather_results[k])[idx]
                    )[:-1]
                else:
                    gather_results[k] = np.array(gather_results[k])[idx]

            logging.info(f"Writing results to {full_path}")
            np.savez_compressed(full_path, **gather_results)
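
Note: because the distributed sampler pads so every rank sees the same number of samples, some system ids are duplicated across the per-rank .npz files; np.unique(..., return_index=True) keeps one row per id and the same indices re-slice the values. A minimal sketch of that de-duplication:

import numpy as np

ids = np.array(["12", "7", "12", "3"])        # "12" appears on two ranks
energies = np.array([0.1, 0.2, 0.1, 0.3])

_, idx = np.unique(ids, return_index=True)    # first occurrence of each id
print(ids[idx])       # ['12' '3' '7']
print(energies[idx])  # [0.1 0.3 0.2]
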
Example #9
 def save(self, epoch, metrics):
     if not self.is_debug and distutils.is_master():
         save_checkpoint(
             {
                 "epoch": epoch,
                 "state_dict": self.model.state_dict(),
                 "optimizer": self.optimizer.state_dict(),
                 "normalizers": {
                     key: value.state_dict()
                     for key, value in self.normalizers.items()
                 },
                 "config": self.config,
                 "val_metrics": metrics,
                 "amp": self.scaler.state_dict() if self.scaler else None,
             },
             self.config["cmd"]["checkpoint_dir"],
         )
Example #10
File: main.py Project: wood-b/ocp
    def __call__(self, config):
        setup_logging()
        self.config = copy.deepcopy(config)

        if args.distributed:
            distutils.setup(config)

        try:
            setup_imports()
            self.trainer = registry.get_trainer_class(
                config.get("trainer", "simple"))(
                    task=config["task"],
                    model=config["model"],
                    dataset=config["dataset"],
                    optimizer=config["optim"],
                    identifier=config["identifier"],
                    timestamp_id=config.get("timestamp_id", None),
                    run_dir=config.get("run_dir", "./"),
                    is_debug=config.get("is_debug", False),
                    is_vis=config.get("is_vis", False),
                    print_every=config.get("print_every", 10),
                    seed=config.get("seed", 0),
                    logger=config.get("logger", "tensorboard"),
                    local_rank=config["local_rank"],
                    amp=config.get("amp", False),
                    cpu=config.get("cpu", False),
                    slurm=config.get("slurm", {}),
                )
            self.task = registry.get_task_class(config["mode"])(self.config)
            self.task.setup(self.trainer)
            start_time = time.time()
            self.task.run()
            distutils.synchronize()
            if distutils.is_master():
                logging.info(f"Total time taken: {time.time() - start_time}")
        finally:
            if args.distributed:
                distutils.cleanup()
Example #11
    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        is_hpo=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
        local_rank=0,
        amp=False,
        cpu=False,
        name="base_trainer",
    ):
        self.name = name
        self.cpu = cpu
        self.start_step = 0

        if torch.cuda.is_available() and not self.cpu:
            self.device = local_rank
        else:
            self.device = "cpu"
            self.cpu = True  # handle case when `--cpu` isn't specified
            # but there are no gpu devices available
        if run_dir is None:
            run_dir = os.getcwd()

        timestamp = torch.tensor(datetime.datetime.now().timestamp()).to(
            self.device)
        # create directories from master rank only
        distutils.broadcast(timestamp, 0)
        timestamp = datetime.datetime.fromtimestamp(
            timestamp.int()).strftime("%Y-%m-%d-%H-%M-%S")
        if identifier:
            timestamp += "-{}".format(identifier)
        try:
            commit_hash = (subprocess.check_output([
                "git",
                "-C",
                ocpmodels.__path__[0],
                "describe",
                "--always",
            ]).strip().decode("ascii"))
        # catch instances where code is not being run from a git repo
        except Exception:
            commit_hash = None

        self.config = {
            "task": task,
            "model": model.pop("name"),
            "model_attributes": model,
            "optim": optimizer,
            "logger": logger,
            "amp": amp,
            "gpus": distutils.get_world_size() if not self.cpu else 0,
            "cmd": {
                "identifier": identifier,
                "print_every": print_every,
                "seed": seed,
                "timestamp": timestamp,
                "commit": commit_hash,
                "checkpoint_dir": os.path.join(run_dir, "checkpoints",
                                               timestamp),
                "results_dir": os.path.join(run_dir, "results", timestamp),
                "logs_dir": os.path.join(run_dir, "logs", logger, timestamp),
            },
        }
        # AMP Scaler
        self.scaler = torch.cuda.amp.GradScaler() if amp else None

        if isinstance(dataset, list):
            self.config["dataset"] = dataset[0]
            if len(dataset) > 1:
                self.config["val_dataset"] = dataset[1]
            if len(dataset) > 2:
                self.config["test_dataset"] = dataset[2]
        else:
            self.config["dataset"] = dataset

        if not is_debug and distutils.is_master() and not is_hpo:
            os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True)
            os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True)
            os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True)

        self.is_debug = is_debug
        self.is_vis = is_vis
        self.is_hpo = is_hpo

        if self.is_hpo:
            # sets the hpo checkpoint frequency
            # default is no checkpointing
            self.hpo_checkpoint_every = self.config["optim"].get(
                "checkpoint_every", -1)

        if distutils.is_master():
            print(yaml.dump(self.config, default_flow_style=False))
        self.load()

        self.evaluator = Evaluator(task=name)
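
Note: the constructor derives all output locations from run_dir, the broadcast timestamp, and the optional identifier. A small sketch of the resulting layout (values are hypothetical):

import datetime
import os

run_dir, logger, identifier = "./", "tensorboard", "schnet-baseline"
timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
if identifier:
    timestamp += "-{}".format(identifier)

print(os.path.join(run_dir, "checkpoints", timestamp))   # ./checkpoints/<ts>-schnet-baseline
print(os.path.join(run_dir, "results", timestamp))       # ./results/<ts>-schnet-baseline
print(os.path.join(run_dir, "logs", logger, timestamp))  # ./logs/tensorboard/<ts>-schnet-baseline
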
Example #12
    def train(self):
        self.best_val_mae = 1e9
        eval_every = self.config["optim"].get("eval_every", -1)
        iters = 0
        self.metrics = {}
        for epoch in range(self.config["optim"]["max_epochs"]):
            self.model.train()
            for i, batch in enumerate(self.train_loader):
                # Forward, loss, backward.
                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                    out = self._forward(batch)
                    loss = self._compute_loss(out, batch)
                loss = self.scaler.scale(loss) if self.scaler else loss
                self._backward(loss)
                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    self.metrics,
                )
                self.metrics = self.evaluator.update("loss",
                                                     loss.item() / scale,
                                                     self.metrics)

                # Print metrics, make plots.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update(
                    {"epoch": epoch + (i + 1) / len(self.train_loader)})
                if i % self.config["cmd"]["print_every"] == 0:
                    log_str = [
                        "{}: {:.4f}".format(k, v) for k, v in log_dict.items()
                    ]
                    print(", ".join(log_str))
                    self.metrics = {}

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

                iters += 1

                # Evaluate on val set every `eval_every` iterations.
                if eval_every != -1 and iters % eval_every == 0:
                    if self.val_loader is not None:
                        val_metrics = self.validate(split="val", epoch=epoch)
                        if (val_metrics[self.evaluator.
                                        task_primary_metric["s2ef"]]["metric"]
                                < self.best_val_mae):
                            self.best_val_mae = val_metrics[
                                self.evaluator.
                                task_primary_metric["s2ef"]]["metric"]
                            if not self.is_debug and distutils.is_master():
                                save_checkpoint(
                                    {
                                        "epoch":
                                        epoch +
                                        (i + 1) / len(self.train_loader),
                                        "state_dict":
                                        self.model.state_dict(),
                                        "optimizer":
                                        self.optimizer.state_dict(),
                                        "normalizers": {
                                            key: value.state_dict()
                                            for key, value in
                                            self.normalizers.items()
                                        },
                                        "config":
                                        self.config,
                                        "val_metrics":
                                        val_metrics,
                                    },
                                    self.config["cmd"]["checkpoint_dir"],
                                )

            self.scheduler.step()
            torch.cuda.empty_cache()

            if eval_every == -1:
                if self.val_loader is not None:
                    val_metrics = self.validate(split="val", epoch=epoch)
                    if (val_metrics[self.evaluator.task_primary_metric["s2ef"]]
                        ["metric"] < self.best_val_mae):
                        self.best_val_mae = val_metrics[
                            self.evaluator.
                            task_primary_metric["s2ef"]]["metric"]
                        if not self.is_debug and distutils.is_master():
                            save_checkpoint(
                                {
                                    "epoch": epoch + 1,
                                    "state_dict": self.model.state_dict(),
                                    "optimizer": self.optimizer.state_dict(),
                                    "normalizers": {
                                        key: value.state_dict()
                                        for key, value in
                                        self.normalizers.items()
                                    },
                                    "config": self.config,
                                    "val_metrics": val_metrics,
                                },
                                self.config["cmd"]["checkpoint_dir"],
                            )

            if self.test_loader is not None:
                self.validate(split="test", epoch=epoch)

            if ("relaxation_dir" in self.config["task"]
                    and self.config["task"].get("ml_relax", "end") == "train"):
                self.validate_relaxation(
                    split="val",
                    epoch=epoch,
                )

        if ("relaxation_dir" in self.config["task"]
                and self.config["task"].get("ml_relax", "end") == "end"):
            self.validate_relaxation(
                split="val",
                epoch=epoch,
            )
Example #13
    def run_relaxations(self, split="val", epoch=None):
        print("### Running ML-relaxations")
        self.model.eval()

        evaluator, metrics = Evaluator(task="is2rs"), {}

        if hasattr(self.relax_dataset[0], "pos_relaxed") and hasattr(
                self.relax_dataset[0], "y_relaxed"):
            split = "val"
        else:
            split = "test"

        ids = []
        relaxed_positions = []
        for i, batch in tqdm(enumerate(self.relax_loader),
                             total=len(self.relax_loader)):
            relaxed_batch = ml_relax(
                batch=batch,
                model=self,
                steps=self.config["task"].get("relaxation_steps", 200),
                fmax=self.config["task"].get("relaxation_fmax", 0.0),
                relax_opt=self.config["task"]["relax_opt"],
                device=self.device,
                transform=None,
            )

            if self.config["task"].get("write_pos", False):
                systemids = [str(i) for i in relaxed_batch.sid.tolist()]
                natoms = relaxed_batch.natoms.tolist()
                positions = torch.split(relaxed_batch.pos, natoms)
                batch_relaxed_positions = [pos.tolist() for pos in positions]

                relaxed_positions += batch_relaxed_positions
                ids += systemids

            if split == "val":
                mask = relaxed_batch.fixed == 0
                s_idx = 0
                natoms_free = []
                for natoms in relaxed_batch.natoms:
                    natoms_free.append(
                        torch.sum(mask[s_idx:s_idx + natoms]).item())
                    s_idx += natoms

                target = {
                    "energy": relaxed_batch.y_relaxed,
                    "positions": relaxed_batch.pos_relaxed[mask],
                    "cell": relaxed_batch.cell,
                    "pbc": torch.tensor([True, True, True]),
                    "natoms": torch.LongTensor(natoms_free),
                }

                prediction = {
                    "energy": relaxed_batch.y,
                    "positions": relaxed_batch.pos[mask],
                    "cell": relaxed_batch.cell,
                    "pbc": torch.tensor([True, True, True]),
                    "natoms": torch.LongTensor(natoms_free),
                }

                metrics = evaluator.eval(prediction, target, metrics)

        if self.config["task"].get("write_pos", False):
            rank = distutils.get_rank()
            pos_filename = os.path.join(self.config["cmd"]["results_dir"],
                                        f"relaxed_pos_{rank}.npz")
            np.savez_compressed(
                pos_filename,
                ids=ids,
                pos=np.array(relaxed_positions, dtype=object),
            )

            distutils.synchronize()
            if distutils.is_master():
                gather_results = defaultdict(list)
                full_path = os.path.join(
                    self.config["cmd"]["results_dir"],
                    "relaxed_positions.npz",
                )

                for i in range(distutils.get_world_size()):
                    rank_path = os.path.join(
                        self.config["cmd"]["results_dir"],
                        f"relaxed_pos_{i}.npz",
                    )
                    rank_results = np.load(rank_path, allow_pickle=True)
                    gather_results["ids"].extend(rank_results["ids"])
                    gather_results["pos"].extend(rank_results["pos"])
                    os.remove(rank_path)

                # Because of how distributed sampler works, some system ids
                # might be repeated to make no. of samples even across GPUs.
                _, idx = np.unique(gather_results["ids"], return_index=True)
                gather_results["ids"] = np.array(gather_results["ids"])[idx]
                gather_results["pos"] = np.array(gather_results["pos"],
                                                 dtype=object)[idx]

                print(f"Writing results to {full_path}")
                np.savez_compressed(full_path, **gather_results)

        if split == "val":
            aggregated_metrics = {}
            for k in metrics:
                aggregated_metrics[k] = {
                    "total":
                    distutils.all_reduce(metrics[k]["total"],
                                         average=False,
                                         device=self.device),
                    "numel":
                    distutils.all_reduce(metrics[k]["numel"],
                                         average=False,
                                         device=self.device),
                }
                aggregated_metrics[k]["metric"] = (
                    aggregated_metrics[k]["total"] /
                    aggregated_metrics[k]["numel"])
            metrics = aggregated_metrics

            # Make plots.
            log_dict = {k: metrics[k]["metric"] for k in metrics}
            if self.logger is not None and epoch is not None:
                self.logger.log(
                    log_dict,
                    step=(epoch + 1) * len(self.train_loader),
                    split=split,
                )

            if distutils.is_master():
                print(metrics)
Example #14
    def predict(self,
                data_loader,
                per_image=True,
                results_file=None,
                disable_tqdm=True):
        if distutils.is_master() and not disable_tqdm:
            print("### Predicting on test.")
        assert isinstance(
            data_loader,
            (
                torch.utils.data.dataloader.DataLoader,
                torch_geometric.data.Batch,
            ),
        )
        rank = distutils.get_rank()

        if isinstance(data_loader, torch_geometric.data.Batch):
            data_loader = [[data_loader]]

        self.model.eval()
        if self.normalizers is not None and "target" in self.normalizers:
            self.normalizers["target"].to(self.device)
            self.normalizers["grad_target"].to(self.device)

        predictions = {"id": [], "energy": [], "forces": []}

        for i, batch_list in tqdm(
                enumerate(data_loader),
                total=len(data_loader),
                position=rank,
                desc="device {}".format(rank),
                disable=disable_tqdm,
        ):
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch_list)

            if self.normalizers is not None and "target" in self.normalizers:
                out["energy"] = self.normalizers["target"].denorm(
                    out["energy"])
                out["forces"] = self.normalizers["grad_target"].denorm(
                    out["forces"])
            if per_image:
                atoms_sum = 0
                systemids = [
                    str(i) + "_" + str(j) for i, j in zip(
                        batch_list[0].sid.tolist(), batch_list[0].fid.tolist())
                ]
                predictions["id"].extend(systemids)
                predictions["energy"].extend(out["energy"].to(
                    torch.float16).tolist())
                batch_natoms = torch.cat(
                    [batch.natoms for batch in batch_list])
                batch_fixed = torch.cat([batch.fixed for batch in batch_list])
                for natoms in batch_natoms:
                    forces = (out["forces"][atoms_sum:natoms +
                                            atoms_sum].cpu().detach().to(
                                                torch.float16).numpy())
                    # evalAI only requires forces on free atoms
                    if results_file is not None:
                        _free_atoms = (batch_fixed[atoms_sum:natoms +
                                                   atoms_sum] == 0).tolist()
                        forces = forces[_free_atoms]
                    atoms_sum += natoms
                    predictions["forces"].append(forces)
            else:
                predictions["energy"] = out["energy"].detach()
                predictions["forces"] = out["forces"].detach()
                return predictions

        predictions["forces"] = np.array(predictions["forces"], dtype=object)
        predictions["energy"] = np.array(predictions["energy"])
        predictions["id"] = np.array(predictions["id"])
        self.save_results(predictions, results_file, keys=["energy", "forces"])
        return predictions
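
Note: the per-image loop above slices a flat (sum(natoms), 3) force tensor into per-system chunks by carrying an atoms_sum offset; the same split can be expressed with torch.split, as run_relaxations does. A minimal sketch with hypothetical sizes:

import torch

out_forces = torch.randn(7, 3)          # batch with 3 + 4 atoms
batch_natoms = torch.tensor([3, 4])

per_system = torch.split(out_forces, batch_natoms.tolist())
print([f.shape for f in per_system])    # [torch.Size([3, 3]), torch.Size([4, 3])]
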
Example #15
    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        normalizer=None,
        timestamp_id=None,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        is_hpo=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
        local_rank=0,
        amp=False,
        cpu=False,
        name="base_trainer",
        slurm={},
    ):
        self.name = name
        self.cpu = cpu
        self.epoch = 0
        self.step = 0

        if torch.cuda.is_available() and not self.cpu:
            self.device = torch.device(f"cuda:{local_rank}")
        else:
            self.device = torch.device("cpu")
            self.cpu = True  # handle case when `--cpu` isn't specified
            # but there are no gpu devices available
        if run_dir is None:
            run_dir = os.getcwd()

        if timestamp_id is None:
            timestamp = torch.tensor(datetime.datetime.now().timestamp()).to(
                self.device
            )
            # create directories from master rank only
            distutils.broadcast(timestamp, 0)
            timestamp = datetime.datetime.fromtimestamp(
                timestamp.int()
            ).strftime("%Y-%m-%d-%H-%M-%S")
            if identifier:
                self.timestamp_id = f"{timestamp}-{identifier}"
            else:
                self.timestamp_id = timestamp
        else:
            self.timestamp_id = timestamp_id

        try:
            commit_hash = (
                subprocess.check_output(
                    [
                        "git",
                        "-C",
                        ocpmodels.__path__[0],
                        "describe",
                        "--always",
                    ]
                )
                .strip()
                .decode("ascii")
            )
        # catch instances where code is not being run from a git repo
        except Exception:
            commit_hash = None

        self.config = {
            "task": task,
            "model": model.pop("name"),
            "model_attributes": model,
            "optim": optimizer,
            "logger": logger,
            "amp": amp,
            "gpus": distutils.get_world_size() if not self.cpu else 0,
            "cmd": {
                "identifier": identifier,
                "print_every": print_every,
                "seed": seed,
                "timestamp_id": self.timestamp_id,
                "commit": commit_hash,
                "checkpoint_dir": os.path.join(
                    run_dir, "checkpoints", self.timestamp_id
                ),
                "results_dir": os.path.join(
                    run_dir, "results", self.timestamp_id
                ),
                "logs_dir": os.path.join(
                    run_dir, "logs", logger, self.timestamp_id
                ),
            },
            "slurm": slurm,
        }
        # AMP Scaler
        self.scaler = torch.cuda.amp.GradScaler() if amp else None

        if "SLURM_JOB_ID" in os.environ and "folder" in self.config["slurm"]:
            self.config["slurm"]["job_id"] = os.environ["SLURM_JOB_ID"]
            self.config["slurm"]["folder"] = self.config["slurm"][
                "folder"
            ].replace("%j", self.config["slurm"]["job_id"])
        if isinstance(dataset, list):
            if len(dataset) > 0:
                self.config["dataset"] = dataset[0]
            if len(dataset) > 1:
                self.config["val_dataset"] = dataset[1]
            if len(dataset) > 2:
                self.config["test_dataset"] = dataset[2]
        elif isinstance(dataset, dict):
            self.config["dataset"] = dataset.get("train", None)
            self.config["val_dataset"] = dataset.get("val", None)
            self.config["test_dataset"] = dataset.get("test", None)
        else:
            self.config["dataset"] = dataset

        self.normalizer = normalizer
        # This supports the legacy way of providing norm parameters in dataset
        if self.config.get("dataset", None) is not None and normalizer is None:
            self.normalizer = self.config["dataset"]

        if not is_debug and distutils.is_master() and not is_hpo:
            os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True)
            os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True)
            os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True)

        self.is_debug = is_debug
        self.is_vis = is_vis
        self.is_hpo = is_hpo

        if self.is_hpo:
            # conditional import is necessary for checkpointing
            from ray import tune

            from ocpmodels.common.hpo_utils import tune_reporter

            # sets the hpo checkpoint frequency
            # default is no checkpointing
            self.hpo_checkpoint_every = self.config["optim"].get(
                "checkpoint_every", -1
            )

        if distutils.is_master():
            print(yaml.dump(self.config, default_flow_style=False))
        self.load()

        self.evaluator = Evaluator(task=name)
Example #16
    def train(self, disable_eval_tqdm=False):
        eval_every = self.config["optim"].get("eval_every", None)
        if eval_every is None:
            eval_every = len(self.train_loader)
        checkpoint_every = self.config["optim"].get("checkpoint_every",
                                                    eval_every)
        primary_metric = self.config["task"].get(
            "primary_metric", self.evaluator.task_primary_metric[self.name])
        self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0
        self.metrics = {}

        # Calculate start_epoch from step instead of loading the epoch number
        # to prevent inconsistencies due to different batch size in checkpoint.
        start_epoch = self.step // len(self.train_loader)

        for epoch_int in range(start_epoch,
                               self.config["optim"]["max_epochs"]):
            self.train_sampler.set_epoch(epoch_int)
            skip_steps = self.step % len(self.train_loader)
            train_loader_iter = iter(self.train_loader)

            for i in range(skip_steps, len(self.train_loader)):
                self.epoch = epoch_int + (i + 1) / len(self.train_loader)
                self.step = epoch_int * len(self.train_loader) + i + 1
                self.model.train()

                # Get a batch.
                batch = next(train_loader_iter)

                if self.config["optim"]["optimizer"] == "LBFGS":

                    def closure():
                        self.optimizer.zero_grad()
                        with torch.cuda.amp.autocast(
                                enabled=self.scaler is not None):
                            out = self._forward(batch)
                            loss = self._compute_loss(out, batch)
                        loss.backward()
                        return loss

                    self.optimizer.step(closure)

                    self.optimizer.zero_grad()
                    with torch.cuda.amp.autocast(
                            enabled=self.scaler is not None):
                        out = self._forward(batch)
                        loss = self._compute_loss(out, batch)

                else:
                    # Forward, loss, backward.
                    with torch.cuda.amp.autocast(
                            enabled=self.scaler is not None):
                        out = self._forward(batch)
                        loss = self._compute_loss(out, batch)
                    loss = self.scaler.scale(loss) if self.scaler else loss
                    self._backward(loss)

                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    self.metrics,
                )
                self.metrics = self.evaluator.update("loss",
                                                     loss.item() / scale,
                                                     self.metrics)

                # Log metrics.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update({
                    "lr": self.scheduler.get_lr(),
                    "epoch": self.epoch,
                    "step": self.step,
                })
                if (self.step % self.config["cmd"]["print_every"] == 0
                        and distutils.is_master() and not self.is_hpo):
                    log_str = [
                        "{}: {:.2e}".format(k, v) for k, v in log_dict.items()
                    ]
                    logging.info(", ".join(log_str))
                    self.metrics = {}

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=self.step,
                        split="train",
                    )

                if checkpoint_every != -1 and self.step % checkpoint_every == 0:
                    self.save(checkpoint_file="checkpoint.pt",
                              training_state=True)

                # Evaluate on val set every `eval_every` iterations.
                if self.step % eval_every == 0:
                    if self.val_loader is not None:
                        val_metrics = self.validate(
                            split="val",
                            disable_tqdm=disable_eval_tqdm,
                        )
                        self.update_best(
                            primary_metric,
                            val_metrics,
                            disable_eval_tqdm=disable_eval_tqdm,
                        )
                        if self.is_hpo:
                            self.hpo_update(
                                self.epoch,
                                self.step,
                                self.metrics,
                                val_metrics,
                            )

                    if self.config["task"].get("eval_relaxations", False):
                        if "relax_dataset" not in self.config["task"]:
                            logging.warning(
                                "Cannot evaluate relaxations, relax_dataset not specified"
                            )
                        else:
                            self.run_relaxations()

                if self.config["optim"].get("print_loss_and_lr", False):
                    if (self.step % eval_every == 0
                            and self.val_loader is not None):
                        print("epoch: " + str(self.epoch) + ", \tstep: " +
                              str(self.step) + ", \tloss: " +
                              str(loss.detach().item()) + ", \tlr: " +
                              str(self.scheduler.get_lr()) + ", \tval: " +
                              str(val_metrics["loss"]["total"]))
                    else:
                        print("epoch: " + str(self.epoch) + ", \tstep: " +
                              str(self.step) + ", \tloss: " +
                              str(loss.detach().item()) + ", \tlr: " +
                              str(self.scheduler.get_lr()))

                if self.scheduler.scheduler_type == "ReduceLROnPlateau":
                    if (self.step % eval_every == 0
                            and self.config["optim"].get(
                                "scheduler_loss", None) == "train"):
                        self.scheduler.step(metrics=loss.detach().item(), )
                    elif self.step % eval_every == 0 and self.val_loader is not None:
                        self.scheduler.step(
                            metrics=val_metrics[primary_metric]["metric"], )
                else:
                    self.scheduler.step()

                break_below_lr = (self.config["optim"].get(
                    "break_below_lr", None) is not None) and (
                        self.scheduler.get_lr() <
                        self.config["optim"]["break_below_lr"])
                if break_below_lr:
                    break
            if break_below_lr:
                break

            torch.cuda.empty_cache()

            if checkpoint_every == -1:
                self.save(checkpoint_file="checkpoint.pt", training_state=True)

        self.train_dataset.close_db()
        if "val_dataset" in self.config:
            self.val_dataset.close_db()
        if "test_dataset" in self.config:
            self.test_dataset.close_db()
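
Note: resuming in this train() is driven entirely by the global step restored from the checkpoint: it determines both the starting epoch and how many batches of that epoch to skip. A minimal sketch of the arithmetic (hypothetical numbers):

step = 2500                # self.step restored from the checkpoint
steps_per_epoch = 1000     # len(self.train_loader)

start_epoch = step // steps_per_epoch   # 2   -> resume in the third epoch
skip_steps = step % steps_per_epoch     # 500 -> skip the first 500 batches of it
print(start_epoch, skip_steps)          # 2 500
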
Example #17
    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
        local_rank=0,
        amp=False,
        name="base_trainer",
    ):
        self.name = name
        if torch.cuda.is_available():
            self.device = local_rank
        else:
            self.device = "cpu"

        if run_dir is None:
            run_dir = os.getcwd()
        run_dir = Path(run_dir)

        timestamp = torch.tensor(datetime.datetime.now().timestamp()).to(
            self.device)
        # create directories from master rank only
        distutils.broadcast(timestamp, 0)
        timestamp = datetime.datetime.fromtimestamp(timestamp).strftime(
            "%Y-%m-%d-%H-%M-%S")
        if identifier:
            timestamp += "-{}".format(identifier)

        self.config = {
            "task": task,
            "model": model.pop("name"),
            "model_attributes": model,
            "optim": optimizer,
            "logger": logger,
            "amp": amp,
            "cmd": {
                "identifier": identifier,
                "print_every": print_every,
                "seed": seed,
                "timestamp": timestamp,
                "checkpoint_dir": str(run_dir / "checkpoints" / timestamp),
                "results_dir": str(run_dir / "results" / timestamp),
                "logs_dir": str(run_dir / "logs" / logger / timestamp),
            },
        }
        # AMP Scaler
        self.scaler = torch.cuda.amp.GradScaler() if amp else None

        if isinstance(dataset, list):
            self.config["dataset"] = dataset[0]
            if len(dataset) > 1:
                self.config["val_dataset"] = dataset[1]
            if len(dataset) > 2:
                self.config["test_dataset"] = dataset[2]
        else:
            self.config["dataset"] = dataset

        if not is_debug and distutils.is_master():
            os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True)
            os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True)
            os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True)

        self.is_debug = is_debug
        self.is_vis = is_vis

        if distutils.is_master():
            print(yaml.dump(self.config, default_flow_style=False))
        self.load()

        self.evaluator = Evaluator(task=name)
Example #18
    def train(self):
        self.best_val_mae = 1e9

        start_epoch = self.start_step // len(self.train_loader)
        for epoch in range(start_epoch, self.config["optim"]["max_epochs"]):
            self.train_sampler.set_epoch(epoch)
            self.model.train()

            skip_steps = 0
            if epoch == start_epoch and start_epoch > 0:
                skip_steps = start_epoch % len(self.train_loader)
            train_loader_iter = iter(self.train_loader)

            for i in range(skip_steps, len(self.train_loader)):
                batch = next(train_loader_iter)
                # Forward, loss, backward.
                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                    out = self._forward(batch)
                    loss = self._compute_loss(out, batch)
                loss = self.scaler.scale(loss) if self.scaler else loss
                self._backward(loss)
                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    metrics={},
                )
                self.metrics = self.evaluator.update("loss",
                                                     loss.item() / scale,
                                                     self.metrics)

                # Print metrics, make plots.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update(
                    {"epoch": epoch + (i + 1) / len(self.train_loader)})
                if (i % self.config["cmd"]["print_every"] == 0
                        and distutils.is_master()):
                    log_str = [
                        "{}: {:.4f}".format(k, v) for k, v in log_dict.items()
                    ]
                    print(", ".join(log_str))

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

                if self.update_lr_on_step:
                    self.scheduler.step()

            if not self.update_lr_on_step:
                self.scheduler.step()

            torch.cuda.empty_cache()

            if self.val_loader is not None:
                val_metrics = self.validate(split="val", epoch=epoch)
                if (val_metrics[self.evaluator.task_primary_metric[self.name]]
                    ["metric"] < self.best_val_mae):
                    self.best_val_mae = val_metrics[
                        self.evaluator.task_primary_metric[
                            self.name]]["metric"]
                    current_step = (epoch + 1) * len(self.train_loader)
                    self.save(epoch + 1, current_step, val_metrics)
                    if self.test_loader is not None:
                        self.predict(
                            self.test_loader,
                            results_file="predictions",
                            disable_tqdm=False,
                        )
            else:
                current_step = (epoch + 1) * len(self.train_loader)
                self.save(epoch + 1, current_step, self.metrics)

        self.train_dataset.close_db()
        if "val_dataset" in self.config:
            self.val_dataset.close_db()
        if "test_dataset" in self.config:
            self.test_dataset.close_db()
Example #19
    def train(self):
        eval_every = self.config["optim"].get("eval_every", -1)
        primary_metric = self.config["task"].get(
            "primary_metric", self.evaluator.task_primary_metric[self.name])
        self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0
        iters = 0
        self.metrics = {}
        for epoch in range(self.config["optim"]["max_epochs"]):
            self.model.train()
            for i, batch in enumerate(self.train_loader):
                # Forward, loss, backward.
                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                    out = self._forward(batch)
                    loss = self._compute_loss(out, batch)
                loss = self.scaler.scale(loss) if self.scaler else loss
                self._backward(loss)
                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    self.metrics,
                )
                self.metrics = self.evaluator.update("loss",
                                                     loss.item() / scale,
                                                     self.metrics)

                # Print metrics, make plots.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update(
                    {"epoch": epoch + (i + 1) / len(self.train_loader)})
                if (i % self.config["cmd"]["print_every"] == 0
                        and distutils.is_master()):
                    log_str = [
                        "{}: {:.4f}".format(k, v) for k, v in log_dict.items()
                    ]
                    print(", ".join(log_str))
                    self.metrics = {}

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

                iters += 1

                # Evaluate on val set every `eval_every` iterations.
                if eval_every != -1 and iters % eval_every == 0:
                    if self.val_loader is not None:
                        val_metrics = self.validate(
                            split="val",
                            epoch=epoch - 1 + (i + 1) / len(self.train_loader),
                        )
                        if ("mae" in primary_metric
                                and val_metrics[primary_metric]["metric"] <
                                self.best_val_metric) or (
                                    val_metrics[primary_metric]["metric"] >
                                    self.best_val_metric):
                            self.best_val_metric = val_metrics[primary_metric][
                                "metric"]
                            current_epoch = epoch + (i + 1) / len(
                                self.train_loader)
                            self.save(current_epoch, val_metrics)
                            if self.test_loader is not None:
                                self.predict(
                                    self.test_loader,
                                    results_file="predictions",
                                    disable_tqdm=False,
                                )

            self.scheduler.step()
            torch.cuda.empty_cache()

            if eval_every == -1:
                if self.val_loader is not None:
                    val_metrics = self.validate(split="val", epoch=epoch)

                    if ("mae" in primary_metric
                            and val_metrics[primary_metric]["metric"] <
                            self.best_val_metric) or (
                                val_metrics[primary_metric]["metric"] >
                                self.best_val_metric):
                        self.best_val_metric = val_metrics[primary_metric][
                            "metric"]
                        self.save(epoch + 1, val_metrics)
                        if self.test_loader is not None:
                            self.predict(
                                self.test_loader,
                                results_file="predictions",
                                disable_tqdm=False,
                            )
                else:
                    self.save(epoch + 1, self.metrics)
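
The training step above relies on the standard PyTorch mixed-precision recipe: the forward pass and loss run under autocast, the loss is scaled before the backward pass, and the logged value is divided by the current scale to recover the unscaled loss. Below is a minimal, self-contained sketch of that recipe; the toy model, optimizer, and data are placeholders rather than the trainer's own objects, and a CUDA device is assumed.

import torch
import torch.nn.functional as F

# Placeholder model/optimizer/data standing in for the trainer's objects.
model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()  # set to None to disable mixed precision

x = torch.randn(8, 16, device="cuda")
y = torch.randn(8, 1, device="cuda")

# Forward + loss under autocast, mirroring the excerpt above.
with torch.cuda.amp.autocast(enabled=scaler is not None):
    out = model(x)
    loss = F.mse_loss(out, y)

optimizer.zero_grad()
scaled_loss = scaler.scale(loss) if scaler else loss
scaled_loss.backward()

# The trainer logs loss.item() / scale; the division undoes the loss scaling.
scale = scaler.get_scale() if scaler else 1.0
print("train loss:", scaled_loss.item() / scale)

# Stepping through the scaler skips the update when gradients overflowed.
if scaler:
    scaler.step(optimizer)
    scaler.update()
else:
    optimizer.step()
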
Ejemplo n.º 20
0
    def train(self):
        eval_every = self.config["optim"].get("eval_every",
                                              len(self.train_loader))
        primary_metric = self.config["task"].get(
            "primary_metric", self.evaluator.task_primary_metric[self.name])
        self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0
        iters = 0
        self.metrics = {}

        start_epoch = self.start_step // len(self.train_loader)
        for epoch in range(start_epoch, self.config["optim"]["max_epochs"]):
            self.train_sampler.set_epoch(epoch)
            skip_steps = 0
            if epoch == start_epoch and start_epoch > 0:
                skip_steps = self.start_step % len(self.train_loader)
            train_loader_iter = iter(self.train_loader)

            for i in range(skip_steps, len(self.train_loader)):
                self.model.train()
                current_epoch = epoch + (i + 1) / len(self.train_loader)
                current_step = epoch * len(self.train_loader) + (i + 1)

                # Get a batch.
                batch = next(train_loader_iter)

                # Forward, loss, backward.
                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                    out = self._forward(batch)
                    loss = self._compute_loss(out, batch)
                loss = self.scaler.scale(loss) if self.scaler else loss
                self._backward(loss)
                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    self.metrics,
                )
                self.metrics = self.evaluator.update("loss",
                                                     loss.item() / scale,
                                                     self.metrics)

                # Log metrics.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update({
                    "lr": self.scheduler.get_lr(),
                    "epoch": current_epoch,
                    "step": current_step,
                })
                if (current_step % self.config["cmd"]["print_every"] == 0
                        and distutils.is_master() and not self.is_hpo):
                    log_str = [
                        "{}: {:.2e}".format(k, v) for k, v in log_dict.items()
                    ]
                    print(", ".join(log_str))
                    self.metrics = {}

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=current_step,
                        split="train",
                    )

                iters += 1

                # Evaluate on val set every `eval_every` iterations.
                if iters % eval_every == 0:
                    if self.val_loader is not None:
                        val_metrics = self.validate(
                            split="val",
                            epoch=epoch - 1 + (i + 1) / len(self.train_loader),
                        )
                        if ("mae" in primary_metric
                                and val_metrics[primary_metric]["metric"] <
                                self.best_val_metric) or (
                                    val_metrics[primary_metric]["metric"] >
                                    self.best_val_metric):
                            self.best_val_metric = val_metrics[primary_metric][
                                "metric"]
                            self.save(current_epoch, current_step, val_metrics)
                            if self.test_loader is not None:
                                self.predict(
                                    self.test_loader,
                                    results_file="predictions",
                                    disable_tqdm=False,
                                )

                        if self.is_hpo:
                            self.hpo_update(
                                current_epoch,
                                current_step,
                                self.metrics,
                                val_metrics,
                            )

                    else:
                        self.save(current_epoch, current_step, self.metrics)

                if self.scheduler.scheduler_type == "ReduceLROnPlateau":
                    if iters % eval_every == 0:
                        self.scheduler.step(
                            metrics=val_metrics[primary_metric]["metric"])
                else:
                    self.scheduler.step()

            torch.cuda.empty_cache()

        self.train_dataset.close_db()
        if "val_dataset" in self.config:
            self.val_dataset.close_db()
        if "test_dataset" in self.config:
            self.test_dataset.close_db()
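
The train method above resumes mid-epoch by recomputing the starting epoch from the global step and skipping the first skip_steps iterations of that epoch. The following is a small sketch of that bookkeeping, assuming a generic DataLoader; the loader, step counts, and epoch limit here are illustrative only.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative loader standing in for self.train_loader (5 batches per epoch).
loader = DataLoader(TensorDataset(torch.arange(10)), batch_size=2)

start_step = 7                                # e.g. restored from a checkpoint
steps_per_epoch = len(loader)

start_epoch = start_step // steps_per_epoch   # -> epoch 1
skip_steps = start_step % steps_per_epoch     # -> skip 2 steps of epoch 1

for epoch in range(start_epoch, 3):
    loader_iter = iter(loader)
    first = skip_steps if epoch == start_epoch else 0
    for i in range(first, steps_per_epoch):
        batch = next(loader_iter)
        step = epoch * steps_per_epoch + i + 1
        # ... forward / backward on `batch`, logging at `step` ...
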
Ejemplo n.º 21
0
    def load_task(self):
        print("### Loading dataset: {}".format(self.config["task"]["dataset"]))

        self.parallel_collater = ParallelCollater(
            1 if not self.cpu else 0,
            self.config["model_attributes"].get("otf_graph", False),
        )
        if self.config["task"]["dataset"] == "trajectory_lmdb":
            self.train_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"])(self.config["dataset"])

            self.train_loader = DataLoader(
                self.train_dataset,
                batch_size=self.config["optim"]["batch_size"],
                shuffle=True,
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
            )

            self.val_loader = self.test_loader = None

            if "val_dataset" in self.config:
                self.val_dataset = registry.get_dataset_class(
                    self.config["task"]["dataset"])(self.config["val_dataset"])
                self.val_loader = DataLoader(
                    self.val_dataset,
                    self.config["optim"].get("eval_batch_size", 64),
                    shuffle=False,
                    collate_fn=self.parallel_collater,
                    num_workers=self.config["optim"]["num_workers"],
                    pin_memory=True,
                )
            if "test_dataset" in self.config:
                self.test_dataset = registry.get_dataset_class(
                    self.config["task"]["dataset"])(
                        self.config["test_dataset"])
                self.test_loader = DataLoader(
                    self.test_dataset,
                    self.config["optim"].get("eval_batch_size", 64),
                    shuffle=False,
                    collate_fn=self.parallel_collater,
                    num_workers=self.config["optim"]["num_workers"],
                    pin_memory=True,
                )

            if "relax_dataset" in self.config["task"]:
                assert os.path.isfile(
                    self.config["task"]["relax_dataset"]["src"])

                self.relax_dataset = registry.get_dataset_class(
                    "single_point_lmdb")(self.config["task"]["relax_dataset"])

                self.relax_sampler = DistributedSampler(
                    self.relax_dataset,
                    num_replicas=distutils.get_world_size(),
                    rank=distutils.get_rank(),
                    shuffle=False,
                )
                self.relax_loader = DataLoader(
                    self.relax_dataset,
                    batch_size=self.config["optim"].get("eval_batch_size", 64),
                    collate_fn=self.parallel_collater,
                    num_workers=self.config["optim"]["num_workers"],
                    pin_memory=True,
                    sampler=self.relax_sampler,
                )

        else:
            self.dataset = registry.get_dataset_class(
                self.config["task"]["dataset"])(self.config["dataset"])
            (
                self.train_loader,
                self.val_loader,
                self.test_loader,
            ) = self.dataset.get_dataloaders(
                batch_size=self.config["optim"]["batch_size"],
                collate_fn=self.parallel_collater,
            )

        self.num_targets = 1

        # Normalizer for the dataset.
        # Compute mean, std of training set labels.
        self.normalizers = {}
        if self.config["dataset"].get("normalize_labels", False):
            if "target_mean" in self.config["dataset"]:
                self.normalizers["target"] = Normalizer(
                    mean=self.config["dataset"]["target_mean"],
                    std=self.config["dataset"]["target_std"],
                    device=self.device,
                )
            else:
                self.normalizers["target"] = Normalizer(
                    tensor=self.train_loader.dataset.data.y[
                        self.train_loader.dataset.__indices__],
                    device=self.device,
                )

        # If we're computing gradients wrt the input, set the normalizer mean
        # to 0 (it is lost when computing dy/dx) and its std to the forward-target std.
        if self.config["model_attributes"].get("regress_forces", True):
            if self.config["dataset"].get("normalize_labels", False):
                if "grad_target_mean" in self.config["dataset"]:
                    self.normalizers["grad_target"] = Normalizer(
                        mean=self.config["dataset"]["grad_target_mean"],
                        std=self.config["dataset"]["grad_target_std"],
                        device=self.device,
                    )
                else:
                    self.normalizers["grad_target"] = Normalizer(
                        tensor=self.train_loader.dataset.data.y[
                            self.train_loader.dataset.__indices__],
                        device=self.device,
                    )
                    self.normalizers["grad_target"].mean.fill_(0)

        if (self.is_vis and self.config["task"]["dataset"] != "qm9"
                and distutils.is_master()):
            # Plot label distribution.
            plots = [
                plot_histogram(
                    self.train_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: train",
                ),
                plot_histogram(
                    self.val_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: val",
                ),
                plot_histogram(
                    self.test_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: test",
                ),
            ]
            self.logger.log_plots(plots)
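
load_task above attaches Normalizer objects that z-score the regression targets, and pins the mean of the gradient (force) normalizer to zero. The Normalizer class itself is not shown in this example, so the following is only a sketch of what such a z-score normalizer typically looks like; the class and variable names are hypothetical.

import torch

class SimpleNormalizer:
    """Hypothetical stand-in for the Normalizer used in load_task above."""

    def __init__(self, tensor=None, mean=None, std=None, device="cpu"):
        if tensor is not None:
            mean, std = tensor.mean(), tensor.std()
        self.mean = torch.as_tensor(mean, dtype=torch.float, device=device)
        self.std = torch.as_tensor(std, dtype=torch.float, device=device)

    def norm(self, x):
        return (x - self.mean) / self.std

    def denorm(self, x):
        return x * self.std + self.mean

# Targets are normalized for the loss and denormalized at prediction time.
y = torch.tensor([1.0, 2.0, 3.0, 4.0])
normalizer = SimpleNormalizer(tensor=y)
assert torch.allclose(normalizer.denorm(normalizer.norm(y)), y)

# For gradient (force) targets the mean is zeroed, as in the code above.
grad_normalizer = SimpleNormalizer(tensor=y)
grad_normalizer.mean.fill_(0)
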
Ejemplo n.º 22
0
    def save(self, epoch, step, metrics):
        if not self.is_debug and distutils.is_master() and not self.is_hpo:
            save_checkpoint(
                self.save_state(epoch, step, metrics),
                self.config["cmd"]["checkpoint_dir"],
            )
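
The save wrapper above writes a checkpoint only on the master rank and only outside debug/HPO runs. The save_checkpoint helper is not part of this example; the following is a hedged sketch of what it plausibly reduces to with plain torch.save, where the state-dictionary keys are assumptions rather than the project's exact format.

import os
import torch

def save_checkpoint(state, checkpoint_dir, filename="checkpoint.pt"):
    """Write `state` to <checkpoint_dir>/<filename> and return the path."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, filename)
    torch.save(state, path)
    return path

# Example state as save_state(epoch, step, metrics) might assemble it (assumed keys).
state = {
    "epoch": 3,
    "step": 1500,
    "state_dict": {},        # model.state_dict() in practice
    "optimizer": {},         # optimizer.state_dict() in practice
    "metrics": {"loss": 0.12},
}
save_checkpoint(state, "./checkpoints/example-run")
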
Ejemplo n.º 23
0
    def train(self, disable_eval_tqdm=False):
        eval_every = self.config["optim"].get("eval_every",
                                              len(self.train_loader))
        primary_metric = self.config["task"].get(
            "primary_metric", self.evaluator.task_primary_metric[self.name])
        self.best_val_mae = 1e9

        # Calculate start_epoch from step instead of loading the epoch number
        # to prevent inconsistencies due to different batch size in checkpoint.
        start_epoch = self.step // len(self.train_loader)

        for epoch_int in range(start_epoch,
                               self.config["optim"]["max_epochs"]):
            self.train_sampler.set_epoch(epoch_int)
            skip_steps = self.step % len(self.train_loader)
            train_loader_iter = iter(self.train_loader)

            for i in range(skip_steps, len(self.train_loader)):
                self.epoch = epoch_int + (i + 1) / len(self.train_loader)
                self.step = epoch_int * len(self.train_loader) + i + 1
                self.model.train()

                # Get a batch.
                batch = next(train_loader_iter)

                # Forward, loss, backward.
                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                    out = self._forward(batch)
                    loss = self._compute_loss(out, batch)
                loss = self.scaler.scale(loss) if self.scaler else loss
                self._backward(loss)
                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    metrics={},
                )
                self.metrics = self.evaluator.update("loss",
                                                     loss.item() / scale,
                                                     self.metrics)

                # Log metrics.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update({
                    "lr": self.scheduler.get_lr(),
                    "epoch": self.epoch,
                    "step": self.step,
                })
                if (self.step % self.config["cmd"]["print_every"] == 0
                        and distutils.is_master() and not self.is_hpo):
                    log_str = [
                        "{}: {:.2e}".format(k, v) for k, v in log_dict.items()
                    ]
                    print(", ".join(log_str))
                    self.metrics = {}

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=self.step,
                        split="train",
                    )

                # Evaluate on val set after every `eval_every` iterations.
                if self.step % eval_every == 0:
                    self.save(checkpoint_file="checkpoint.pt",
                              training_state=True)

                    if self.val_loader is not None:
                        val_metrics = self.validate(
                            split="val",
                            disable_tqdm=disable_eval_tqdm,
                        )
                        if (val_metrics[self.evaluator.task_primary_metric[
                                self.name]]["metric"] < self.best_val_mae):
                            self.best_val_mae = val_metrics[
                                self.evaluator.task_primary_metric[
                                    self.name]]["metric"]
                            self.save(
                                metrics=val_metrics,
                                checkpoint_file="best_checkpoint.pt",
                                training_state=False,
                            )
                            if self.test_loader is not None:
                                self.predict(
                                    self.test_loader,
                                    results_file="predictions",
                                    disable_tqdm=False,
                                )

                        if self.is_hpo:
                            self.hpo_update(
                                self.epoch,
                                self.step,
                                self.metrics,
                                val_metrics,
                            )

                    else:
                        self.save(metrics=self.metrics,
                                  checkpoint_file="checkpoint.pt",
                                  training_state=True)

                if self.scheduler.scheduler_type == "ReduceLROnPlateau":
                    if self.step % eval_every == 0:
                        self.scheduler.step(
                            metrics=val_metrics[primary_metric]["metric"])
                else:
                    self.scheduler.step()

            torch.cuda.empty_cache()

        self.train_dataset.close_db()
        if "val_dataset" in self.config:
            self.val_dataset.close_db()
        if "test_dataset" in self.config:
            self.test_dataset.close_db()
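
The loop above steps a ReduceLROnPlateau scheduler only when fresh validation metrics exist, and steps every other scheduler type once per iteration. Below is a small sketch of that branching with stock PyTorch schedulers; the trainer's own scheduler wrapper (and its scheduler_type attribute) is replaced here by a plain isinstance check, and the metric value is a placeholder.

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

use_plateau = True
if use_plateau:
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=2)
else:
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)

eval_every = 10
val_metric = 0.25  # e.g. the primary validation metric from validate()

for step in range(1, 31):
    # ... training step on one batch ...
    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        # Plateau schedulers need a metric, so only step after an evaluation.
        if step % eval_every == 0:
            scheduler.step(metrics=val_metric)
    else:
        scheduler.step()
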
Ejemplo n.º 24
0
    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
        local_rank=0,
        amp=False,
    ):

        if run_dir is None:
            run_dir = os.getcwd()

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        if identifier:
            timestamp += "-{}".format(identifier)

        self.config = {
            "task": task,
            "model": model.pop("name"),
            "model_attributes": model,
            "optim": optimizer,
            "logger": logger,
            "amp": amp,
            "cmd": {
                "identifier": identifier,
                "print_every": print_every,
                "seed": seed,
                "timestamp": timestamp,
                "checkpoint_dir": os.path.join(run_dir, "checkpoints",
                                               timestamp),
                "results_dir": os.path.join(run_dir, "results", timestamp),
                "logs_dir": os.path.join(run_dir, "logs", logger, timestamp),
            },
        }
        # AMP Scaler
        self.scaler = torch.cuda.amp.GradScaler() if amp else None

        if isinstance(dataset, list):
            self.config["dataset"] = dataset[0]
            if len(dataset) > 1:
                self.config["val_dataset"] = dataset[1]
        else:
            self.config["dataset"] = dataset

        if not is_debug and distutils.is_master():
            os.makedirs(self.config["cmd"]["checkpoint_dir"])
            os.makedirs(self.config["cmd"]["results_dir"])
            os.makedirs(self.config["cmd"]["logs_dir"])

        self.is_debug = is_debug
        self.is_vis = is_vis
        if torch.cuda.is_available():
            self.device = local_rank
        else:
            self.device = "cpu"

        if distutils.is_master():
            print(yaml.dump(self.config, default_flow_style=False))
        self.load()

        self.evaluator = Evaluator(task="s2ef")