Example #1
class AtomsTrainer:
    def __init__(self, config):
        self.config = config
        self.pretrained = False

    def load(self):
        self.load_config()
        self.load_rng_seed()
        self.load_dataset()
        self.load_model()
        self.load_criterion()
        self.load_optimizer()
        self.load_logger()
        self.load_extras()
        self.load_skorch()

    def load_config(self):
        self.timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        self.identifier = self.config["cmd"].get("identifier", False)
        if self.identifier:
            self.identifier = self.timestamp + "-{}".format(self.identifier)
        else:
            self.identifier = self.timestamp

        self.device = torch.device(self.config["optim"].get("device", "cpu"))
        self.debug = self.config["cmd"].get("debug", False)
        run_dir = self.config["cmd"].get("run_dir", "./")
        os.chdir(run_dir)
        if not self.debug:
            self.cp_dir = os.path.join(run_dir, "checkpoints", self.identifier)
            print(f"Results saved to {self.cp_dir}")
            os.makedirs(self.cp_dir, exist_ok=True)

    def load_rng_seed(self):
        seed = self.config["cmd"].get("seed", 0)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def get_unique_elements(self, training_images):
        elements = np.array(
            [atom.symbol for atoms in training_images for atom in atoms])
        elements = np.unique(elements)
        return elements

    def load_dataset(self):
        training_images = self.config["dataset"]["raw_data"]
        # TODO: Scalability when dataset is too large to fit into memory
        if isinstance(training_images, str):
            training_images = ase.io.read(training_images, ":")
        self.elements = self.config["dataset"].get(
            "elements", self.get_unique_elements(training_images))

        self.forcetraining = self.config["model"].get("get_forces", True)
        self.fp_scheme = self.config["dataset"].get("fp_scheme",
                                                    "gaussian").lower()
        self.fp_params = self.config["dataset"]["fp_params"]
        self.cutoff_params = self.config["dataset"].get(
            "cutoff_params", {"cutoff_func": "Cosine"})

        self.train_dataset = AtomsDataset(
            images=training_images,
            descriptor_setup=(
                self.fp_scheme,
                self.fp_params,
                self.cutoff_params,
                self.elements,
            ),
            forcetraining=self.forcetraining,
            save_fps=self.config["dataset"].get("save_fps", True),
        )

        self.target_scaler = self.train_dataset.target_scaler
        if not self.debug:
            normalizers = {"target": self.target_scaler}
            torch.save(normalizers, os.path.join(self.cp_dir,
                                                 "normalizers.pt"))
        self.input_dim = self.train_dataset.input_dim
        self.val_split = self.config["dataset"].get("val_split", 0)
        print("Loading dataset: {} images".format(len(self.train_dataset)))

    def load_model(self):
        elements = list_symbols_to_indices(self.elements)
        self.model = BPNN(elements=elements,
                          input_dim=self.input_dim,
                          **self.config["model"])
        print("Loading model: {} parameters".format(self.model.num_params))

    def load_extras(self):
        callbacks = []
        load_best_loss = train_end_load_best_loss(self.identifier)
        self.split = CVSplit(cv=self.val_split) if self.val_split != 0 else 0

        metrics = evaluator(
            self.val_split,
            self.config["optim"].get("metric", "mae"),
            self.identifier,
            self.forcetraining,
        )
        callbacks.extend(metrics)

        if not self.debug:
            callbacks.append(load_best_loss)
        scheduler = self.config["optim"].get("scheduler", None)
        if scheduler:
            scheduler = LRScheduler(scheduler,
                                    **self.config["optim"]["scheduler_params"])
            callbacks.append(scheduler)
        if self.config["cmd"].get("logger", False):
            from skorch.callbacks import WandbLogger

            callbacks.append(
                WandbLogger(
                    self.wandb_run,
                    save_model=False,
                    keys_ignored="dur",
                ))
        self.callbacks = callbacks

    def load_criterion(self):
        self.criterion = self.config["optim"].get("loss_fn", CustomLoss)

    def load_optimizer(self):
        self.optimizer = self.config["optim"].get("optimizer",
                                                  torch.optim.Adam)

    def load_logger(self):
        if self.config["cmd"].get("logger", False):
            import wandb

            self.wandb_run = wandb.init(
                name=self.identifier,
                config=self.config,
                id=self.timestamp,
            )

    def load_skorch(self):
        # replace skorch's default to_tensor with the package's custom conversion
        skorch.net.to_tensor = to_tensor

        collate_fn = DataCollater(train=True, forcetraining=self.forcetraining)

        self.net = NeuralNetRegressor(
            module=self.model,
            criterion=self.criterion,
            criterion__force_coefficient=self.config["optim"].get(
                "force_coefficient", 0),
            criterion__loss=self.config["optim"].get("loss", "mse"),
            optimizer=self.optimizer,
            lr=self.config["optim"].get("lr", 1e-1),
            batch_size=self.config["optim"].get("batch_size", 32),
            max_epochs=self.config["optim"].get("epochs", 100),
            iterator_train__collate_fn=collate_fn,
            iterator_train__shuffle=True,
            iterator_valid__collate_fn=collate_fn,
            iterator_valid__shuffle=False,
            device=self.device,
            train_split=self.split,
            callbacks=self.callbacks,
            verbose=self.config["cmd"].get("verbose", True),
        )
        print("Loading skorch trainer")

    def train(self, raw_data=None):
        if raw_data is not None:
            self.config["dataset"]["raw_data"] = raw_data
        if not self.pretrained:
            self.load()

        self.net.fit(self.train_dataset, None)

    def predict(self, images, batch_size=32):
        if len(images) < 1:
            warnings.warn("No images found!", stacklevel=2)
            return images

        a2d = AtomsToData(
            descriptor=self.train_dataset.descriptor,
            r_energy=False,
            r_forces=False,
            save_fps=True,
            fprimes=self.forcetraining,
            cores=1,
        )

        data_list = a2d.convert_all(images, disable_tqdm=True)

        self.net.module.eval()
        collate_fn = DataCollater(train=False,
                                  forcetraining=self.forcetraining)

        predictions = {"energy": [], "forces": []}
        for data in data_list:
            collated = collate_fn([data])
            energy, forces = self.net.module(collated)

            energy = self.target_scaler.denorm(
                energy, pred="energy").detach().tolist()
            forces = self.target_scaler.denorm(forces,
                                               pred="forces").detach().numpy()

            predictions["energy"].extend(energy)
            predictions["forces"].append(forces)

        return predictions

    def load_pretrained(self, checkpoint_path=None):
        print(f"Loading checkpoint from {checkpoint_path}")
        self.load()
        self.net.initialize()
        self.pretrained = True
        try:
            self.net.load_params(
                f_params=os.path.join(checkpoint_path, "params.pt"),
                f_optimizer=os.path.join(checkpoint_path, "optimizer.pt"),
                f_criterion=os.path.join(checkpoint_path, "criterion.pt"),
                f_history=os.path.join(checkpoint_path, "history.json"),
            )
            # TODO(mshuaibi): remove dataset load, use saved normalizers
        except NotImplementedError:
            print("Unable to load checkpoint!")
Example #2
class AtomsTrainer:
    def __init__(self, config={}):
        self.config = config
        self.pretrained = False

    def load(self, load_dataset=True):
        self.load_config()
        self.load_rng_seed()
        if load_dataset:
            self.load_dataset()
        self.load_model()
        self.load_criterion()
        self.load_optimizer()
        self.load_logger()
        self.load_extras()
        self.load_skorch()

    def load_config(self):
        dtype = self.config["cmd"].get("dtype", torch.FloatTensor)
        torch.set_default_tensor_type(dtype)
        self.timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        self.identifier = self.config["cmd"].get("identifier", False)
        if self.identifier:
            self.identifier = self.timestamp + "-{}".format(self.identifier)
        else:
            self.identifier = self.timestamp

        self.gpus = self.config["optim"].get("gpus", 0)
        if self.gpus > 0:
            self.output_device = 0
            self.device = f"cuda:{self.output_device}"
        else:
            self.device = "cpu"
            self.output_device = -1
        self.debug = self.config["cmd"].get("debug", False)
        run_dir = self.config["cmd"].get("run_dir", "./")
        os.chdir(run_dir)
        if not self.debug:
            self.cp_dir = os.path.join(run_dir, "checkpoints", self.identifier)
            print(f"Results saved to {self.cp_dir}")
            os.makedirs(self.cp_dir, exist_ok=True)

    def load_rng_seed(self):
        seed = self.config["cmd"].get("seed", 0)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def get_unique_elements(self, training_images):
        elements = np.array(
            [atom.symbol for atoms in training_images for atom in atoms]
        )
        elements = np.unique(elements)
        return elements

    def load_dataset(self):
        training_images = self.config["dataset"]["raw_data"]
        # TODO: Scalability when dataset is too large to fit into memory
        if isinstance(training_images, str):
            training_images = ase.io.read(training_images, ":")
        del self.config["dataset"]["raw_data"]

        self.elements = self.config["dataset"].get(
            "elements", self.get_unique_elements(training_images)
        )

        self.forcetraining = self.config["model"].get("get_forces", True)
        self.fp_scheme = self.config["dataset"].get("fp_scheme", "gaussian").lower()
        self.fp_params = self.config["dataset"]["fp_params"]
        self.save_fps = self.config["dataset"].get("save_fps", True)
        self.cutoff_params = self.config["dataset"].get(
            "cutoff_params", {"cutoff_func": "Cosine"}
        )
        descriptor_setup = (
            self.fp_scheme,
            self.fp_params,
            self.cutoff_params,
            self.elements,
        )
        self.train_dataset = AtomsDataset(
            images=training_images,
            descriptor_setup=descriptor_setup,
            forcetraining=self.forcetraining,
            save_fps=self.config["dataset"].get("save_fps", True),
            scaling=self.config["dataset"].get(
                "scaling", {"type": "normalize", "range": (0, 1)}
            ),
        )
        self.feature_scaler = self.train_dataset.feature_scaler
        self.target_scaler = self.train_dataset.target_scaler
        self.input_dim = self.train_dataset.input_dim
        self.val_split = self.config["dataset"].get("val_split", 0)
        self.config["dataset"]["descriptor"] = descriptor_setup
        if not self.debug:
            normalizers = {
                "target": self.target_scaler,
                "feature": self.feature_scaler,
            }
            torch.save(normalizers, os.path.join(self.cp_dir, "normalizers.pt"))
            # clean/organize config
            self.config["dataset"]["fp_length"] = self.input_dim
            torch.save(self.config, os.path.join(self.cp_dir, "config.pt"))
        print("Loading dataset: {} images".format(len(self.train_dataset)))

    def load_model(self):
        elements = list_symbols_to_indices(self.elements)
        self.model = BPNN(
            elements=elements, input_dim=self.input_dim, **self.config["model"]
        )
        print("Loading model: {} parameters".format(self.model.num_params))
        self.forcetraining = self.config["model"].get("get_forces", True)
        collate_fn = DataCollater(train=True, forcetraining=self.forcetraining)
        self.parallel_collater = ParallelCollater(self.gpus, collate_fn)
        if self.gpus > 0:
            self.model = DataParallel(
                self.model,
                output_device=self.output_device,
                num_gpus=self.gpus,
            )

    def load_extras(self):
        callbacks = []
        load_best_loss = train_end_load_best_loss(self.identifier)
        self.val_split = self.config["dataset"].get("val_split", 0)
        self.split = CVSplit(cv=self.val_split) if self.val_split != 0 else 0

        metrics = evaluator(
            self.val_split,
            self.config["optim"].get("metric", "mae"),
            self.identifier,
            self.forcetraining,
        )
        callbacks.extend(metrics)

        if not self.debug:
            callbacks.append(load_best_loss)
        scheduler = self.config["optim"].get("scheduler", None)
        if scheduler:
            scheduler = LRScheduler(scheduler["policy"], **scheduler["params"])
            callbacks.append(scheduler)
        if self.config["cmd"].get("logger", False):
            from skorch.callbacks import WandbLogger

            callbacks.append(
                WandbLogger(
                    self.wandb_run,
                    save_model=False,
                    keys_ignored="dur",
                )
            )
        self.callbacks = callbacks

    def load_criterion(self):
        self.criterion = self.config["optim"].get("loss_fn", CustomLoss)

    def load_optimizer(self):
        self.optimizer = {
            "optimizer": self.config["optim"].get("optimizer", torch.optim.Adam)
        }
        optimizer_args = self.config["optim"].get("optimizer_args", False)
        if optimizer_args:
            self.optimizer.update(optimizer_args)

    def load_logger(self):
        if self.config["cmd"].get("logger", False):
            import wandb

            self.wandb_run = wandb.init(
                name=self.identifier,
                config=self.config,
            )

    def load_skorch(self):
        # replace skorch's default to_tensor with the package's custom conversion
        skorch.net.to_tensor = to_tensor

        self.net = NeuralNetRegressor(
            module=self.model,
            criterion=self.criterion,
            criterion__force_coefficient=self.config["optim"].get(
                "force_coefficient", 0
            ),
            criterion__loss=self.config["optim"].get("loss", "mse"),
            lr=self.config["optim"].get("lr", 1e-1),
            batch_size=self.config["optim"].get("batch_size", 32),
            max_epochs=self.config["optim"].get("epochs", 100),
            iterator_train__collate_fn=self.parallel_collater,
            iterator_train__shuffle=True,
            iterator_train__pin_memory=True,
            iterator_valid__collate_fn=self.parallel_collater,
            iterator_valid__shuffle=False,
            iterator_valid__pin_memory=True,
            device=self.device,
            train_split=self.split,
            callbacks=self.callbacks,
            verbose=self.config["cmd"].get("verbose", True),
            **self.optimizer,
        )
        print("Loading skorch trainer")

    def train(self, raw_data=None):
        if raw_data is not None:
            self.config["dataset"]["raw_data"] = raw_data
        if not self.pretrained:
            self.load()

        stime = time.time()
        self.net.fit(self.train_dataset, None)
        elapsed_time = time.time() - stime
        print(f"Training completed in {elapsed_time}s")

    def predict(self, images, disable_tqdm=True):
        if len(images) < 1:
            warnings.warn("No images found!", stacklevel=2)
            return images

        self.descriptor = construct_descriptor(self.config["dataset"]["descriptor"])

        a2d = AtomsToData(
            descriptor=self.descriptor,
            r_energy=False,
            r_forces=False,
            save_fps=self.config["dataset"].get("save_fps", True),
            fprimes=self.forcetraining,
            cores=1,
        )

        data_list = a2d.convert_all(images, disable_tqdm=disable_tqdm)
        self.feature_scaler.norm(data_list, disable_tqdm=disable_tqdm)

        self.net.module.eval()
        collate_fn = DataCollater(train=False, forcetraining=self.forcetraining)

        predictions = {"energy": [], "forces": []}
        for data in data_list:
            collated = collate_fn([data]).to(self.device)
            energy, forces = self.net.module([collated])

            energy = self.target_scaler.denorm(
                energy.detach().cpu(), pred="energy"
            ).tolist()
            forces = self.target_scaler.denorm(
                forces.detach().cpu(), pred="forces"
            ).numpy()

            predictions["energy"].extend(energy)
            predictions["forces"].append(forces)

        return predictions

    def load_pretrained(self, checkpoint_path=None, gpu2cpu=False):
        """
        Args:
            checkpoint_path: str, Path to checkpoint directory
            gpu2cpu: bool, True if checkpoint was trained with GPUs and you
                wish to load on CPU instead.
        """

        self.pretrained = True
        print(f"Loading checkpoint from {checkpoint_path}")
        assert os.path.isdir(
            checkpoint_path
        ), f"Checkpoint: {checkpoint_path} not found!"
        if not self.config:
            # prediction only
            self.config = torch.load(os.path.join(checkpoint_path, "config.pt"))
            self.config["cmd"]["debug"] = True
            self.elements = self.config["dataset"]["descriptor"][-1]
            self.input_dim = self.config["dataset"]["fp_length"]
            if gpu2cpu:
                self.config["optim"]["gpus"] = 0
            self.load(load_dataset=False)
        else:
            # prediction+retraining
            self.load(load_dataset=True)
        self.net.initialize()

        if gpu2cpu:
            params_path = os.path.join(checkpoint_path, "params_cpu.pt")
            if not os.path.exists(params_path):
                params = torch.load(
                    os.path.join(checkpoint_path, "params.pt"),
                    map_location=torch.device("cpu"),
                )
                new_dict = OrderedDict()
                for k, v in params.items():
                    # strip the "module." prefix that DataParallel prepends to parameter names
                    name = k[7:]
                    new_dict[name] = v
                torch.save(new_dict, params_path)
        else:
            params_path = os.path.join(checkpoint_path, "params.pt")

        try:
            self.net.load_params(
                f_params=params_path,
                f_optimizer=os.path.join(checkpoint_path, "optimizer.pt"),
                f_criterion=os.path.join(checkpoint_path, "criterion.pt"),
                f_history=os.path.join(checkpoint_path, "history.json"),
            )
            normalizers = torch.load(os.path.join(checkpoint_path, "normalizers.pt"))
            self.feature_scaler = normalizers["feature"]
            self.target_scaler = normalizers["target"]
        except NotImplementedError:
            print("Unable to load checkpoint!")

    def get_calc(self):
        return AMPtorch(self)
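A hedged sketch of the inference-only workflow enabled by the second variant's load_pretrained: with an empty config, the saved config.pt, parameters, and normalizers are restored from the checkpoint directory. The checkpoint and trajectory paths are placeholders, and treating the object returned by get_calc() as a standard ASE calculator is an assumption, not something shown in the source.

import ase.io

trainer = AtomsTrainer()                    # empty config triggers the prediction-only branch
trainer.load_pretrained("./checkpoints/example-run")  # placeholder checkpoint directory

images = ase.io.read("test.traj", ":")      # placeholder trajectory
predictions = trainer.predict(images)
energies = predictions["energy"]
forces = predictions["forces"]

# Alternatively, attach the trained model to a single structure as a calculator
# (assumes AMPtorch follows the ASE Calculator interface).
atoms = images[0].copy()
atoms.set_calculator(trainer.get_calc())
print(atoms.get_potential_energy())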