def save(self, file_path):
    """Serialize this model to *file_path* via ``torch.save``.

    The checkpoint records the concrete class name (for registry lookup on
    load), the constructor params, and the network weights moved to CPU so
    the file can be loaded on machines without a GPU.
    """
    # Move weights to CPU before pickling so the checkpoint is device-agnostic.
    cpu_weights = to_device(self.nn_module.state_dict(), 'cpu')
    checkpoint = {
        'model_name': self.__class__.__name__,
        'params': self.params,
        'nn_state_dict': cpu_weights,
    }
    torch.save(checkpoint, file_path)
    self.logger.info(f"Model saved to '{file_path}'")
def predict(self, input):
    """Run a forward pass on *input* and return the transformed prediction.

    Ensures the module is in eval mode, disables gradient tracking, moves
    the input to the model's device, and applies ``self.prediction_transform``
    to the raw network output.

    NOTE(review): validation via ``assert`` is stripped under ``python -O``;
    kept as-is to preserve the exception type callers may rely on.
    """
    assert self.predict_ready()
    with torch.no_grad():
        # Switch out of training mode once, if needed, before inference.
        if self.nn_module.training:
            self.nn_module.eval()
        batch = to_device(input, self.device)
        raw_output = self.nn_module(batch)
        return self.prediction_transform(raw_output)
def load_model(file_path, device=None):
    """Reconstruct a model from a checkpoint written by ``save``.

    Parameters
    ----------
    file_path : str
        Path to a checkpoint containing 'model_name', 'params' and
        'nn_state_dict' entries.
    device : str or torch.device, optional
        If given, overrides the device stored in the checkpoint params.

    Returns
    -------
    The instantiated model, with weights loaded and set to eval mode.

    Raises
    ------
    FileNotFoundError
        If *file_path* does not exist.
    ImportError
        If the checkpoint's model class is not in ``MODEL_REGISTRY``.
    """
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"No state found at {file_path}")
    # map_location='cpu' lets checkpoints that reference CUDA storages load
    # on CPU-only hosts; weights are moved to the target device below anyway.
    state = torch.load(file_path, map_location='cpu')
    model_name = state['model_name']
    if model_name not in MODEL_REGISTRY:
        raise ImportError(
            f"Model '{model_name}' not found in scope")
    params = state['params']
    if device is not None:
        # NOTE(review): .type keeps only the device kind ('cuda:1' -> 'cuda'),
        # discarding the index — preserved from the original; confirm intended.
        params['device'] = torch.device(device).type
    model_class = MODEL_REGISTRY[model_name]
    model = model_class(params)
    nn_state_dict = to_device(state['nn_state_dict'], model.device)
    model.nn_module.load_state_dict(nn_state_dict)
    model.nn_module.eval()
    return model
def prepare_batch(self, batch, device):
    """Unpack an (input, target) batch and move both parts to *device*."""
    inputs, targets = batch
    inputs = to_device(inputs, device)
    targets = to_device(targets, device)
    return inputs, targets
def prepare_unlabeled_batch(self, batch, device):
    """Augment a labeled batch with sampled unlabeled inputs.

    Concatenates ``self.sample_unlabeled_input()`` onto the labeled inputs
    along the batch dimension, then moves inputs and targets to *device*.
    The targets correspond only to the labeled (leading) portion.
    """
    labeled_inputs, targets = batch
    extra = self.sample_unlabeled_input()
    combined = torch.cat([labeled_inputs, extra], dim=0)
    return to_device(combined, device), to_device(targets, device)