Beispiel #1
0
def _train_model_step(model, observation, optimizer, mask, logger):
    """Run one optimization step on ``model`` and log the resulting loss.

    Parameters
    ----------
    model: EnsembleModel, NNModel or ExactGPModel.
        Model to train. ``EnsembleModel`` is tested before ``NNModel``
        (presumably because ensembles are NN models too — verify), so the
        ensemble-specific step that consumes ``mask`` takes precedence.
    observation: Observation or mapping.
        Batch to train on. A mapping is converted with ``Observation(**obs)``.
        NOTE: ``observation.action`` is truncated in place to the first
        ``model.dim_action[0]`` entries of its last axis, so an ``Observation``
        passed in by the caller is modified.
    optimizer: Optimizer.
        Optimizer forwarded to the type-specific training step.
    mask: Tensor.
        Mask forwarded to ``train_ensemble_step``; ignored for other models.
    logger: Logger.
        Receives the loss under the key ``"<model_kind[:3]>-loss"``.

    Returns
    -------
    loss: float.
        Scalar value of the training loss of this step.

    Raises
    ------
    TypeError
        If ``model`` is not one of the supported model types.
    """
    if not isinstance(observation, Observation):
        observation = Observation(**observation)
    # Keep only the action dimensions the model was built for.
    observation.action = observation.action[..., :model.dim_action[0]]

    if isinstance(model, EnsembleModel):
        loss = train_ensemble_step(model, observation, optimizer, mask)
    elif isinstance(model, NNModel):
        loss = train_nn_step(model, observation, optimizer)
    elif isinstance(model, ExactGPModel):
        loss = train_exact_gp_type2mll_step(model, observation, optimizer)
    else:
        raise TypeError("Only Implemented for Ensembles, NN and GP Models.")

    loss_value = loss.item()
    logger.update(**{f"{model.model_kind[:3]}-loss": loss_value})
    return loss_value
Beispiel #2
0
def _validate_model_step(model, observation, logger):
    """Compute and log validation metrics of ``model`` on one batch.

    Parameters
    ----------
    model: AbstractModel.
        Model to evaluate.
    observation: Observation or mapping.
        Batch to evaluate on. A mapping is converted with ``Observation(**obs)``.
        NOTE: ``observation.action`` is truncated in place to the first
        ``model.dim_action[0]`` entries of its last axis, so an ``Observation``
        passed in by the caller is modified.
    logger: Logger.
        Receives the MSE, sharpness and calibration score under keys prefixed
        with ``model.model_kind[:3]``.

    Returns
    -------
    mse: float.
        Validation mean squared error of the model on this batch.
    """
    if not isinstance(observation, Observation):
        observation = Observation(**observation)
    # Keep only the action dimensions the model was built for.
    observation.action = observation.action[..., :model.dim_action[0]]

    # Every metric is reduced to a Python float with ``.item()``, so no
    # gradient can flow out of them; disabling autograd saves memory and time.
    with torch.no_grad():
        mse = model_mse(model, observation).item()
        sharpness_ = sharpness(model, observation).item()
        calibration_score_ = calibration_score(model, observation).item()

    prefix = model.model_kind[:3]
    logger.update(
        **{
            f"{prefix}-val-mse": mse,
            f"{prefix}-sharp": sharpness_,
            f"{prefix}-calib": calibration_score_,
        })
    return mse
Beispiel #3
0
def train_model(
        model,
        train_set,
        optimizer,
        batch_size=100,
        max_iter=100,
        epsilon=0.1,
        non_decrease_iter=float("inf"),
        logger=None,
        validation_set=None,
):
    """Train a Predictive Model.

    Each iteration is one pass over ``train_set`` (shuffled) followed by one
    pass over ``validation_set`` (not shuffled). Training loss and validation
    metrics (MSE, sharpness, calibration) are logged per batch, and every
    validation MSE is fed to the early-stopping criterion.

    Parameters
    ----------
    model: AbstractModel.
        Predictive model to optimize.
    train_set: ExperienceReplay.
        Dataset to train with.
    optimizer: Optimizer.
        Optimizer to call for the model.
    batch_size: int (default=100).
        Batch size to iterate through.
    max_iter: int (default=100).
        Maximum number of epochs.
    epsilon: float (default=0.1).
        Early stopping parameter. If epoch loss is > (1 + epsilon) of minimum loss the
        optimization process stops.
    non_decrease_iter: int, optional.
        Early stopping parameter. If epoch loss does not decrease for consecutive
        non_decrease_iter, the optimization process stops.
    logger: Logger, optional.
        Progress logger. Defaults to ``Logger(f"{model.name}_training")``.
    validation_set: ExperienceReplay, optional.
        Dataset to validate with. Defaults to ``train_set``.

    Raises
    ------
    TypeError
        If ``model`` is not an ensemble, NN or exact GP model.

    Notes
    -----
    The model is put in evaluation mode while validation metrics are computed
    and is always left in training mode when this function returns.
    """
    if logger is None:
        logger = Logger(f"{model.name}_training")
    if validation_set is None:
        validation_set = train_set

    # Loop invariants: log-key prefix and number of action dims to keep.
    prefix = model.model_kind[:3]
    dim_action = model.dim_action[0]

    early_stopping = EarlyStopping(epsilon,
                                   non_decrease_iter=non_decrease_iter)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    validation_loader = DataLoader(validation_set,
                                   batch_size=batch_size,
                                   shuffle=False)

    try:
        for _ in tqdm(range(max_iter)):
            # Training pass. Re-enable training mode every epoch because the
            # previous validation pass switched the model to eval mode.
            model.train()
            for observation, _idx, mask in train_loader:
                observation = Observation(**observation)
                observation.action = observation.action[..., :dim_action]
                if isinstance(model, EnsembleModel):
                    loss = train_ensemble_step(model, observation, optimizer,
                                               mask)
                elif isinstance(model, NNModel):
                    loss = train_nn_step(model, observation, optimizer)
                elif isinstance(model, ExactGPModel):
                    loss = train_exact_gp_type2mll_step(model, observation,
                                                        optimizer)
                else:
                    raise TypeError(
                        "Only Implemented for Ensembles, NN and GP Models.")
                logger.update(**{f"{prefix}-loss": loss.item()})

            # Validation pass in eval mode so dropout / batch-norm / GP
            # posterior behave as at prediction time.
            model.eval()
            for observation, _idx, _mask in validation_loader:
                observation = Observation(**observation)
                observation.action = observation.action[..., :dim_action]

                with torch.no_grad():
                    mse = model_mse(model, observation).item()
                    sharpness_ = sharpness(model, observation).item()
                    calibration_score_ = calibration_score(
                        model, observation).item()

                logger.update(
                    **{
                        f"{prefix}-val-mse": mse,
                        f"{prefix}-sharp": sharpness_,
                        f"{prefix}-calib": calibration_score_,
                    })

                early_stopping.update(mse)

            if early_stopping.stop:
                return
            early_stopping.reset(hard=False)  # reset to zero the moving averages.
    finally:
        # Hand the model back in training mode however the loop exits
        # (early stop, exhausted iterations, or an exception).
        model.train()