Example No. 1
    def test_early_stopper3(self):
        # With patience 3 and min_delta 0.2, three consecutive checks without
        # a sufficient improvement over the best loss return False (stop).
        e = EarlyStopper(3, .2, True)
        n = NaiveBase(2, 2)
        e.check_loss(n, .9)
        e.check_loss(n, 1.1)
        e.check_loss(n, 1.2)
        self.assertFalse(e.check_loss(n, .8))
Example No. 2
    def test_early_stopper(self):
        # None of the later losses beats the running best by the 0.2
        # min_delta, so the third failed check exhausts the patience of 3.
        e = EarlyStopper(3, .2)
        n = NaiveBase(2, 2)
        e.check_loss(n, .9)
        e.check_loss(n, .8)
        e.check_loss(n, .9)
        self.assertFalse(e.check_loss(n, .75))
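
The assertions in these tests pin down the EarlyStopper contract: check_loss returns True while training should continue, and False once the loss has failed to beat the best score by min_delta for patience consecutive checks; a "checkpoint.pth" is saved on each improvement (the training loops below reload that file when stopping early). A minimal sketch consistent with the tests above, assuming an ignite-style stopper; the real flow-forecast class may differ in details:

import torch


class EarlyStopper:
    def __init__(self, patience, min_delta=0.0, cumulative_delta=False):
        # cumulative_delta is the assumed meaning of the third positional arg.
        self.patience = patience
        self.min_delta = min_delta
        self.cumulative_delta = cumulative_delta
        self.counter = 0
        self.best_score = None

    def check_loss(self, model, validation_loss):
        # Higher score is better, so negate the loss.
        score = -validation_loss
        if self.best_score is None:
            self.best_score = score
            self.save_model_checkpoint(model)
        elif score <= self.best_score + self.min_delta:
            # Not enough improvement. Unless cumulative_delta is set, still
            # track small gains so the delta is measured from the latest best.
            if not self.cumulative_delta and score > self.best_score:
                self.best_score = score
            self.counter += 1
            if self.counter >= self.patience:
                return False
        else:
            self.best_score = score
            self.save_model_checkpoint(model)
            self.counter = 0
        return True

    def save_model_checkpoint(self, model):
        # The training loops below reload "checkpoint.pth" on early stop.
        torch.save(model.state_dict(), "checkpoint.pth")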
Example No. 3
def train_transformer_style(model: PyTorchForecast,
                            training_params: Dict,
                            takes_target=False,
                            forward_params: Dict = {},
                            model_filepath: str = "model_save") -> None:
    """Function to train any PyTorchForecast model

    :param model:  A properly wrapped PyTorchForecast model
    :type model: PyTorchForecast
    :param training_params: A dictionary of the necessary parameters for training.
    :type training_params: Dict
    :param takes_target: A parameter to determine whether a model requires the target, defaults to False
    :type takes_target: bool, optional
    :param forward_params: Additional keyword arguments passed to the model's forward method, defaults to {}
    :type forward_params: Dict, optional
    :param model_filepath: The file path to save model weights to, defaults to "model_save"
    :type model_filepath: str, optional
    :raises ValueError: If the validation loss is zero, indicating a problem with the validator
    """
    use_wandb = model.wandb
    es = None
    worker_num = 1
    pin_memory = False
    dataset_params = model.params["dataset_params"]
    num_targets = 1
    if "n_targets" in model.params:
        num_targets = model.params["n_targets"]
    if "num_workers" in dataset_params:
        worker_num = dataset_params["num_workers"]
        print("using " + str(worker_num))
    if "pin_memory" in dataset_params:
        pin_memory = dataset_params["pin_memory"]
        print("Pin memory set to true")
    if "early_stopping" in model.params:
        es = EarlyStopper(model.params["early_stopping"]['patience'])
    opt = pytorch_opt_dict[training_params["optimizer"]](
        model.model.parameters(), **training_params["optim_params"])
    criterion_init_params = {}
    if "criterion_params" in training_params:
        criterion_init_params = training_params["criterion_params"]
    criterion = pytorch_criterion_dict[training_params["criterion"]](
        **criterion_init_params)
    if "probabilistic" in model.params[
            "model_params"] or "probabilistic" in model.params:
        probabilistic = True
    else:
        probabilistic = False
    max_epochs = training_params["epochs"]
    data_loader = DataLoader(model.training,
                             batch_size=training_params["batch_size"],
                             shuffle=False,
                             sampler=None,
                             batch_sampler=None,
                             num_workers=worker_num,
                             collate_fn=None,
                             pin_memory=pin_memory,
                             drop_last=False,
                             timeout=0,
                             worker_init_fn=None)
    validation_data_loader = DataLoader(
        model.validation,
        batch_size=training_params["batch_size"],
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=worker_num,
        collate_fn=None,
        pin_memory=pin_memory,
        drop_last=False,
        timeout=0,
        worker_init_fn=None)
    # TODO support batch_size > 1
    test_data_loader = DataLoader(model.test_data,
                                  batch_size=1,
                                  shuffle=False,
                                  sampler=None,
                                  batch_sampler=None,
                                  num_workers=worker_num,
                                  collate_fn=None,
                                  pin_memory=pin_memory,
                                  drop_last=False,
                                  timeout=0,
                                  worker_init_fn=None)
    meta_model = None
    meta_representation = None
    meta_loss = None
    if model.params.get("meta_data") is None:
        model.params["meta_data"] = False
    if model.params["meta_data"]:
        meta_model, meta_representation, meta_loss = handle_meta_data(model)
    if use_wandb:
        wandb.watch(model.model)
    use_decoder = False
    if "use_decoder" in model.params:
        use_decoder = True
    session_params = []
    for epoch in range(max_epochs):
        total_loss = torch_single_train(model,
                                        opt,
                                        criterion,
                                        data_loader,
                                        takes_target,
                                        meta_model,
                                        meta_representation,
                                        meta_loss,
                                        multi_targets=num_targets,
                                        forward_params=forward_params.copy())
        print("The loss for epoch " + str(epoch))
        print(total_loss)
        valid = compute_validation(
            validation_data_loader,
            model.model,
            epoch,
            model.params["dataset_params"]["forecast_length"],
            model.crit,
            model.device,
            multi_targets=num_targets,
            meta_model=meta_model,
            decoder_structure=use_decoder,
            use_wandb=use_wandb,
            probabilistic=probabilistic)
        if valid == 0.0:
            raise ValueError(
                "Error: validation loss is zero; there is a problem with the validator."
            )
        if use_wandb:
            wandb.log({'epoch': epoch, 'loss': total_loss})
        epoch_params = {
            "epoch": epoch,
            "train_loss": str(total_loss),
            "validation_loss": str(valid)
        }
        session_params.append(epoch_params)
        if es:
            if not es.check_loss(model.model, valid):
                print("Stopping model now")
                model.model.load_state_dict(torch.load("checkpoint.pth"))
                break
    decoder_structure = True
    if model.params["dataset_params"]["class"] == "AutoEncoder":
        decoder_structure = False
    test = compute_validation(
        test_data_loader,
        model.model,
        epoch,
        model.params["dataset_params"]["forecast_length"],
        model.crit,
        model.device,
        meta_model=meta_model,
        multi_targets=num_targets,
        decoder_structure=decoder_structure,
        use_wandb=use_wandb,
        val_or_test="test_loss",
        probabilistic=probabilistic)
    print("test loss:", test)
    model.params["run"] = session_params
    model.save_model(model_filepath, max_epochs)
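
For reference, a hypothetical call to the function above; the "Adam" and "MSE" keys are assumptions about the contents of pytorch_opt_dict and pytorch_criterion_dict, and model stands for an already-constructed PyTorchForecast:

training_params = {
    "optimizer": "Adam",           # looked up in pytorch_opt_dict
    "optim_params": {"lr": 1e-3},  # forwarded to the optimizer constructor
    "criterion": "MSE",            # looked up in pytorch_criterion_dict
    "criterion_params": {},        # optional kwargs for the criterion
    "epochs": 10,
    "batch_size": 32,
}
# Optional knobs the function also inspects on model.params:
#   dataset_params["num_workers"], dataset_params["pin_memory"],
#   "n_targets", "early_stopping": {"patience": ...}, "use_decoder",
#   "meta_data", and "probabilistic" (in model_params or at the top level).
train_transformer_style(model, training_params)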
Example No. 4
def train_transformer_style(model: PyTorchForecast,
                            training_params: Dict,
                            takes_target=False,
                            forward_params: Dict = {},
                            model_filepath: str = "model_save") -> None:
    """
    Function to train any PyTorchForecast model
    :model The initialized PyTorchForecastModel
    :training_params_dict A dictionary of the parameters needed to train model
    :takes_target boolean: Determines whether to pass target during training
    :forward_params: A dictionary for additional forward parameters (for instance target)
    """
    use_wandb = model.wandb
    es = None
    if "early_stopping" in model.params:
        es = EarlyStopper(model.params["early_stopping"]['patience'])
    opt = pytorch_opt_dict[training_params["optimizer"]](
        model.model.parameters(), **training_params["optim_params"])
    criterion_init_params = {}
    if "criterion_params" in training_params:
        criterion_init_params = training_params["criterion_params"]
    criterion = pytorch_criterion_dict[training_params["criterion"]](
        **criterion_init_params)
    max_epochs = training_params["epochs"]
    data_loader = DataLoader(model.training,
                             batch_size=training_params["batch_size"],
                             shuffle=False,
                             sampler=None,
                             batch_sampler=None,
                             num_workers=0,
                             collate_fn=None,
                             pin_memory=False,
                             drop_last=False,
                             timeout=0,
                             worker_init_fn=None)
    validation_data_loader = DataLoader(
        model.validation,
        batch_size=training_params["batch_size"],
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
        timeout=0,
        worker_init_fn=None)
    test_data_loader = DataLoader(model.test_data,
                                  batch_size=1,
                                  shuffle=False,
                                  sampler=None,
                                  batch_sampler=None,
                                  num_workers=0,
                                  collate_fn=None,
                                  pin_memory=False,
                                  drop_last=False,
                                  timeout=0,
                                  worker_init_fn=None)
    if use_wandb:
        import wandb
        wandb.watch(model.model)
    session_params = []
    for epoch in range(max_epochs):
        total_loss = torch_single_train(model, opt, criterion, data_loader,
                                        takes_target, forward_params)
        print("The loss for epoch " + str(epoch))
        print(total_loss)
        use_decoder = False
        if "use_decoder" in model.params:
            use_decoder = True
        valid = compute_validation(
            validation_data_loader,
            model.model,
            epoch,
            model.params["dataset_params"]["forecast_length"],
            criterion,
            model.device,
            decoder_structure=use_decoder,
            use_wandb=use_wandb)
        if valid < 0.01:
            raise ValueError(
                "Error: validation loss is below 0.01; there is a problem with the validator."
            )
        if use_wandb:
            wandb.log({'epoch': epoch, 'loss': total_loss})
        epoch_params = {
            "epoch": epoch,
            "train_loss": str(total_loss),
            "validation_loss": str(valid)
        }
        session_params.append(epoch_params)
        if es:
            if not es.check_loss(model.model, valid):
                print("Stopping model now")
                model.model.load_state_dict(torch.load("checkpoint.pth"))
                break
    decoder_structure = True
    if model.params["dataset_params"]["class"] != "default":
        decoder_structure = False
    test = compute_validation(
        test_data_loader,
        model.model,
        epoch,
        model.params["dataset_params"]["forecast_length"],
        criterion,
        model.device,
        decoder_structure=decoder_structure,
        use_wandb=use_wandb,
        val_or_test="test_loss")
    print("test loss:", test)
    model.params["run"] = session_params
    model.save_model(model_filepath, max_epochs)
Example No. 5
    def test_early_stopper2(self):
        # Only two checks fall short of the 0.2 min_delta improvement, so the
        # patience of 3 is not exhausted and check_loss still returns True.
        e = EarlyStopper(3, .2)
        n = NaiveBase(2, 2)
        e.check_loss(n, .9)
        e.check_loss(n, .7)
        self.assertTrue(e.check_loss(n, .6))
def train_transformer_style(
        model: PyTorchForecast,
        training_params: Dict,
        takes_target=False,
        forward_params: Dict = {},
        model_filepath: str = "model_save") -> None:
    """
    Function to train any PyTorchForecast model
    :model The initialized PyTorchForecastModel
    :training_params_dict A dictionary of the parameters needed to train model
    :takes_target boolean: Determines whether to pass target during training
    :forward_params: A dictionary for additional forward parameters (for instance target)
    """
    use_wandb = model.wandb
    es = None
    worker_num = 1
    pin_memory = False
    dataset_params = model.params["dataset_params"]
    if "num_workers" in dataset_params:
        worker_num = dataset_params["num_workers"]
    if "pin_memory" in dataset_params:
        pin_memory = dataset_params["pin_memory"]
        print("Pin memory set to true")
    if "early_stopping" in model.params:
        es = EarlyStopper(model.params["early_stopping"]['patience'])
    opt = pytorch_opt_dict[training_params["optimizer"]](
        model.model.parameters(), **training_params["optim_params"])
    criterion_init_params = {}
    if "criterion_params" in training_params:
        criterion_init_params = training_params["criterion_params"]
    criterion = pytorch_criterion_dict[training_params["criterion"]](**criterion_init_params)
    if "probabilistic" in training_params:
        probabilistic = True
    else:
        probabilistic = False
    max_epochs = training_params["epochs"]
    data_loader = DataLoader(
        model.training,
        batch_size=training_params["batch_size"],
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=worker_num,
        collate_fn=None,
        pin_memory=pin_memory,
        drop_last=False,
        timeout=0,
        worker_init_fn=None)
    validation_data_loader = DataLoader(
        model.validation,
        batch_size=training_params["batch_size"],
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=worker_num,
        collate_fn=None,
        pin_memory=pin_memory,
        drop_last=False,
        timeout=0,
        worker_init_fn=None)
    test_data_loader = DataLoader(model.test_data, batch_size=1, shuffle=False, sampler=None,
                                  batch_sampler=None, num_workers=worker_num, collate_fn=None,
                                  pin_memory=pin_memory, drop_last=False, timeout=0,
                                  worker_init_fn=None)
    meta_model = None
    meta_representation = None
    if model.params.get("meta_data") is None:
        model.params["meta_data"] = False
    if model.params["meta_data"]:
        with open(model.params["meta_data"]["path"]) as f:
            json_data = json.load(f)
        dataset_params2 = json_data["dataset_params"]
        training_path = dataset_params2["training_path"]
        valid_path = dataset_params2["validation_path"]
        meta_name = json_data["model_name"]
        meta_model = PyTorchForecast(meta_name, training_path, valid_path, dataset_params2["test_path"], json_data)
        meta_representation = get_meta_representation(model.params["meta_data"]["column_id"],
                                                      model.params["meta_data"]["uuid"], meta_model)
    if use_wandb:
        wandb.watch(model.model)
    session_params = []
    for epoch in range(max_epochs):
        total_loss = torch_single_train(
            model,
            opt,
            criterion,
            data_loader,
            takes_target,
            meta_model,
            meta_representation,
            forward_params.copy())
        print("The loss for epoch " + str(epoch))
        print(total_loss)
        use_decoder = False
        if "use_decoder" in model.params:
            use_decoder = True
        valid = compute_validation(
            validation_data_loader,
            model.model,
            epoch,
            model.params["dataset_params"]["forecast_length"],
            criterion,
            model.device,
            meta_model=meta_model,
            decoder_structure=use_decoder,
            use_wandb=use_wandb,
            probabilistic=probabilistic)
        if valid == 0.0:
            raise ValueError("Error: validation loss is zero; there is a problem with the validator.")
        if use_wandb:
            wandb.log({'epoch': epoch, 'loss': total_loss})
        epoch_params = {
            "epoch": epoch,
            "train_loss": str(total_loss),
            "validation_loss": str(valid)}
        session_params.append(epoch_params)
        if es:
            if not es.check_loss(model.model, valid):
                print("Stopping model now")
                model.model.load_state_dict(torch.load("checkpoint.pth"))
                break
    decoder_structure = True
    if model.params["dataset_params"]["class"] != "default":
        decoder_structure = False
    test = compute_validation(
        test_data_loader,
        model.model,
        epoch,
        model.params["dataset_params"]["forecast_length"],
        criterion,
        model.device,
        meta_model=meta_model,
        decoder_structure=decoder_structure,
        use_wandb=use_wandb,
        val_or_test="test_loss",
        probabilistic=probabilistic)
    print("test loss:", test)
    model.params["run"] = session_params
    model.save_model(model_filepath, max_epochs)
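
Example No. 5 wires up the meta model inline from a JSON config rather than through handle_meta_data as in Example No. 3. A hypothetical shape for that configuration; the key names ("path", "column_id", "uuid") come from the code above, while the values are made up:

model.params["meta_data"] = {
    "path": "meta_config.json",  # JSON file providing "model_name" and "dataset_params"
    "column_id": "gauge_id",     # hypothetical column fed to get_meta_representation
    "uuid": "01013500",          # hypothetical id fed to get_meta_representation
}
# Per the code above, meta_config.json must itself contain at least:
#   {"model_name": ..., "dataset_params": {"training_path": ...,
#    "validation_path": ..., "test_path": ...}}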