import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Observation, EnsembleModel, NNModel, ExactGPModel, the per-model training
# steps (train_ensemble_step, train_nn_step, train_exact_gp_type2mll_step),
# the metrics (model_mse, sharpness, calibration_score), Logger, and
# EarlyStopping are assumed to be imported from the surrounding library.


def _train_model_step(model, observation, optimizer, mask, logger):
    """Run a single optimization step of the model on a batch of observations."""
    if not isinstance(observation, Observation):
        observation = Observation(**observation)
    observation.action = observation.action[..., : model.dim_action[0]]

    # Dispatch to the training step that matches the model class.
    if isinstance(model, EnsembleModel):
        loss = train_ensemble_step(model, observation, optimizer, mask)
    elif isinstance(model, NNModel):
        loss = train_nn_step(model, observation, optimizer)
    elif isinstance(model, ExactGPModel):
        loss = train_exact_gp_type2mll_step(model, observation, optimizer)
    else:
        raise TypeError("Only implemented for Ensemble, NN, and Exact GP models.")
    logger.update(**{f"{model.model_kind[:3]}-loss": loss.item()})
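
# For intuition, each per-model step dispatched above has roughly the same
# shape: zero the gradients, score the observed transition under the model,
# backpropagate, and take an optimizer step. The sketch below is illustrative
# only; it assumes a model whose forward pass returns a torch.distributions
# object over the next state, and it is NOT the library's train_nn_step.
def _illustrative_nn_step(model, observation, optimizer):
    optimizer.zero_grad()
    prediction = model(observation.state, observation.action)  # assumed predictive distribution
    loss = -prediction.log_prob(observation.next_state).mean()  # negative log-likelihood
    loss.backward()
    optimizer.step()
    return loss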

def _validate_model_step(model, observation, logger):
    """Log validation metrics (MSE, sharpness, calibration) and return the MSE."""
    if not isinstance(observation, Observation):
        observation = Observation(**observation)
    observation.action = observation.action[..., : model.dim_action[0]]

    mse = model_mse(model, observation).item()
    sharpness_ = sharpness(model, observation).item()
    calibration_score_ = calibration_score(model, observation).item()
    logger.update(
        **{
            f"{model.model_kind[:3]}-val-mse": mse,
            f"{model.model_kind[:3]}-sharp": sharpness_,
            f"{model.model_kind[:3]}-calib": calibration_score_,
        }
    )
    return mse
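
# The three validation metrics summarize complementary aspects of the
# predictive distribution: MSE measures accuracy of the mean prediction,
# sharpness measures how concentrated the predictions are, and the
# calibration score measures how well predicted confidence matches observed
# frequencies. As a hedged illustration of the first (NOT the library's
# model_mse), an MSE on the mean of an assumed predictive distribution:
def _illustrative_model_mse(model, observation):
    with torch.no_grad():
        prediction = model(observation.state, observation.action)  # assumed distribution
        return ((prediction.mean - observation.next_state) ** 2).mean()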

def train_model(
    model,
    train_set,
    optimizer,
    batch_size=100,
    max_iter=100,
    epsilon=0.1,
    non_decrease_iter=float("inf"),
    logger=None,
    validation_set=None,
):
    """Train a predictive model.

    Parameters
    ----------
    model: AbstractModel.
        Predictive model to optimize.
    train_set: ExperienceReplay.
        Dataset to train with.
    optimizer: Optimizer.
        Optimizer to call for the model.
    batch_size: int (default=100).
        Batch size to iterate through.
    max_iter: int (default=100).
        Maximum number of epochs.
    epsilon: float.
        Early-stopping parameter. If the epoch loss is greater than
        (1 + epsilon) times the minimum loss seen so far, optimization stops.
    non_decrease_iter: int, optional.
        Early-stopping parameter. If the epoch loss does not decrease for
        non_decrease_iter consecutive epochs, optimization stops.
    logger: Logger, optional.
        Progress logger.
    validation_set: ExperienceReplay, optional.
        Dataset to validate with. Defaults to the training set.
    """
    if logger is None:
        logger = Logger(f"{model.name}_training")
    if validation_set is None:
        validation_set = train_set

    model.train()
    early_stopping = EarlyStopping(epsilon, non_decrease_iter=non_decrease_iter)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    validation_loader = DataLoader(validation_set, batch_size=batch_size, shuffle=False)

    for _ in tqdm(range(max_iter)):
        for observation, idx, mask in train_loader:
            _train_model_step(model, observation, optimizer, mask, logger)

        with torch.no_grad():
            for observation, idx, mask in validation_loader:
                mse = _validate_model_step(model, observation, logger)

        early_stopping.update(mse)
        if early_stopping.stop:
            return
        early_stopping.reset(hard=False)  # reset the moving averages to zero.
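
# train_model relies on a small EarlyStopping interface: update(value) feeds
# the latest validation MSE, .stop flags when training should halt, and
# reset(hard=False) clears the running value while keeping the minimum seen
# so far. The class below is a minimal, self-contained sketch of that
# interface under the rules documented above (stop when the loss exceeds
# (1 + epsilon) times the minimum, or fails to decrease for non_decrease_iter
# consecutive updates); it is NOT the library's EarlyStopping implementation.
class _SketchEarlyStopping:
    def __init__(self, epsilon=0.1, non_decrease_iter=float("inf")):
        self.epsilon = epsilon
        self.non_decrease_iter = non_decrease_iter
        self.min_value = float("inf")
        self.current_value = float("inf")
        self.num_non_decrease = 0

    @property
    def stop(self):
        """Return True when either early-stopping criterion fires."""
        if self.min_value == float("inf"):
            return False
        return (
            self.current_value > (1 + self.epsilon) * self.min_value
            or self.num_non_decrease >= self.non_decrease_iter
        )

    def update(self, value):
        """Record a new validation value and track non-decreasing streaks."""
        self.current_value = value
        if value < self.min_value:
            self.min_value = value
            self.num_non_decrease = 0
        else:
            self.num_non_decrease += 1

    def reset(self, hard=True):
        """Reset running statistics; a soft reset keeps the minimum value."""
        if hard:
            self.min_value = float("inf")
            self.num_non_decrease = 0
        self.current_value = float("inf")


# Hypothetical usage of train_model (the constructor arguments below are
# assumptions for illustration, not the exact library API):
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#     train_model(model, train_set=replay_buffer, optimizer=optimizer,
#                 max_iter=50, non_decrease_iter=5)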