def load_model(self):
    super(EnergyTrainer, self).load_model()

    self.model = OCPDataParallel(
        self.model,
        output_device=self.device,
        num_gpus=self.config["optim"].get("num_gpus", 1),
    )
    if distutils.initialized():
        self.model = DistributedDataParallel(
            self.model, device_ids=[self.device]
        )
def load_model(self):
    # Build model
    if distutils.is_master():
        logging.info(f"Loading model: {self.config['model']}")

    # TODO(abhshkdz): Eventually move towards computing features on-the-fly
    # and remove dependence from `.edge_attr`.
    bond_feat_dim = None
    if self.config["task"]["dataset"] in [
        "trajectory_lmdb",
        "single_point_lmdb",
    ]:
        bond_feat_dim = self.config["model_attributes"].get(
            "num_gaussians", 50
        )
    else:
        raise NotImplementedError

    loader = self.train_loader or self.val_loader or self.test_loader
    self.model = registry.get_model_class(self.config["model"])(
        loader.dataset[0].x.shape[-1]
        if loader
        and hasattr(loader.dataset[0], "x")
        and loader.dataset[0].x is not None
        else None,
        bond_feat_dim,
        self.num_targets,
        **self.config["model_attributes"],
    ).to(self.device)

    if distutils.is_master():
        logging.info(
            f"Loaded {self.model.__class__.__name__} with "
            f"{self.model.num_params} parameters."
        )

    if self.logger is not None:
        self.logger.watch(self.model)

    self.model = OCPDataParallel(
        self.model,
        output_device=self.device,
        num_gpus=1 if not self.cpu else 0,
    )
    if distutils.initialized():
        self.model = DistributedDataParallel(
            self.model, device_ids=[self.device]
        )
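# For reference, a rough sketch of the config entries load_model() reads.
# The concrete values below are illustrative assumptions, not a canonical config:
#
# config = {
#     "model": "schnet",                    # name resolved via registry.get_model_class
#     "model_attributes": {"num_gaussians": 50},  # forwarded as **kwargs to the model
#     "task": {"dataset": "single_point_lmdb"},   # or "trajectory_lmdb"
#     "optim": {"num_gpus": 1},
# }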
def load_loss(self):
    self.loss_fn = {}
    self.loss_fn["energy"] = self.config["optim"].get("loss_energy", "mae")
    self.loss_fn["force"] = self.config["optim"].get("loss_force", "mae")
    # Map loss-name strings to torch loss modules.
    for loss, loss_name in self.loss_fn.items():
        if loss_name in ["l1", "mae"]:
            self.loss_fn[loss] = nn.L1Loss()
        elif loss_name == "mse":
            self.loss_fn[loss] = nn.MSELoss()
        elif loss_name == "l2mae":
            self.loss_fn[loss] = L2MAELoss()
        else:
            raise NotImplementedError(
                f"Unknown loss function name: {loss_name}"
            )
        # In distributed training, wrap the loss for use with DDP.
        if distutils.initialized():
            self.loss_fn[loss] = DDPLoss(self.loss_fn[loss])
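# A minimal sketch of what an "l2mae"-style loss could look like: the per-row
# L2 norm of the error, averaged (or summed) over rows. This is an illustrative
# assumption, not necessarily the exact L2MAELoss used above.
class _L2MAELossSketch(nn.Module):
    def __init__(self, reduction="mean"):
        super().__init__()
        assert reduction in ("mean", "sum")
        self.reduction = reduction

    def forward(self, input, target):
        # L2 distance per sample (e.g. per-atom force vector), then reduce.
        dists = torch.norm(input - target, p=2, dim=-1)
        return dists.mean() if self.reduction == "mean" else dists.sum()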
def load_checkpoint(self, checkpoint_path):
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            errno.ENOENT, "Checkpoint file not found", checkpoint_path
        )

    logging.info(f"Loading checkpoint from: {checkpoint_path}")
    map_location = torch.device("cpu") if self.cpu else self.device
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    self.epoch = checkpoint.get("epoch", 0)
    self.step = checkpoint.get("step", 0)

    # Load model, optimizer, normalizer state dict.
    # If trained with DDP and loading in non-DDP, modify keys from
    # module.module.. -> module..
    first_key = next(iter(checkpoint["state_dict"]))
    if not distutils.initialized() and first_key.split(".")[1] == "module":
        # No need for OrderedDict since dictionaries are technically ordered
        # since Python 3.6 and officially ordered since Python 3.7
        new_dict = {k[7:]: v for k, v in checkpoint["state_dict"].items()}
        self.model.load_state_dict(new_dict)
    else:
        self.model.load_state_dict(checkpoint["state_dict"])

    if "optimizer" in checkpoint:
        self.optimizer.load_state_dict(checkpoint["optimizer"])
    if "scheduler" in checkpoint and checkpoint["scheduler"] is not None:
        self.scheduler.scheduler.load_state_dict(checkpoint["scheduler"])
    if "ema" in checkpoint and checkpoint["ema"] is not None:
        self.ema.load_state_dict(checkpoint["ema"])
    for key in checkpoint["normalizers"]:
        if key in self.normalizers:
            self.normalizers[key].load_state_dict(
                checkpoint["normalizers"][key]
            )
    if self.scaler and checkpoint["amp"]:
        self.scaler.load_state_dict(checkpoint["amp"])
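# For reference, the checkpoint dict that load_checkpoint() expects, based on
# the keys read above (types are descriptive, not exhaustive):
#
# {
#     "epoch": int,
#     "step": int,
#     "state_dict": model state dict (possibly with "module." prefixes from DDP),
#     "optimizer": optimizer state dict (optional),
#     "scheduler": LR scheduler state dict or None,
#     "ema": EMA state dict or None,
#     "normalizers": {target_name: normalizer state dict},
#     "amp": grad scaler state dict or None,
# }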