def test_decoder_multi_step(self):
    if "save_path" in self.model_params:
        del self.model_params["save_path"]
    forecast_model = train_function("PyTorch", self.model_params)
    t = torch.Tensor([3, 4, 5]).repeat(1, 336, 1)
    output = simple_decode(forecast_model.model,
                           torch.ones(1, 5, 3),
                           336,
                           t,
                           output_len=3)
    # Check for target leakage: the known value 3 must never appear in the output
    self.assertFalse(3 in output[:, :, 0])
def test_multivariate_single_step(self):
    # Remove any stale save_path left over from a previous test
    if "save_path" in self.model_params3:
        del self.model_params3["save_path"]
    t = torch.Tensor([3, 6, 5]).repeat(1, 100, 1)
    forecast_model3 = train_function("PyTorch", self.model_params3)
    output = simple_decode(forecast_model3.model,
                           torch.ones(1, 5, 3),
                           100,
                           t,
                           output_len=3,
                           multi_targets=2)
    # Neither known target value should leak into the decoded output
    self.assertFalse(3 in output)
    self.assertFalse(6 in output)
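
For reference, a minimal self-contained sketch (dummy tensors, no trained model involved) of why the membership test above catches target leakage: if the decoder ever copied the sentinel value 3 from the supplied real targets into its predictions, 3 in output would be True.

import torch

# Dummy stand-ins: real_target mirrors the tensor built in the test above,
# fake_output stands in for the simple_decode result.
real_target = torch.Tensor([3, 4, 5]).repeat(1, 336, 1)  # shape (1, 336, 3)
fake_output = torch.randn(1, 336, 3)
print(real_target.shape)              # torch.Size([1, 336, 3])
print(3 in fake_output[:, :, 0])      # almost surely False for random floats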
Example 3
def compute_validation(validation_loader: DataLoader,
                       model,
                       epoch: int,
                       sequence_size: int,
                       criterion: Type[torch.nn.modules.loss._Loss],
                       device: torch.device,
                       decoder_structure=False,
                       meta_data_model=None,
                       use_wandb: bool = False,
                       meta_model=None,
                       multi_targets=1,
                       val_or_test="validation_loss",
                       probabilistic=False) -> float:
    """Function to compute the validation loss metrics

    :param validation_loader: The data-loader of either validation or test-data
    :type validation_loader: DataLoader
    :param model: model
    :type model: [type]
    :param epoch: The epoch where the validation/test loss is being computed.
    :type epoch: int
    :param sequence_size: [description]
    :type sequence_size: int
    :param criterion: [description]
    :type criterion: Type[torch.nn.modules.loss._Loss]
    :param device: The device
    :type device: torch.device
    :param decoder_structure: Whether the model should use sequential decoding, defaults to False
    :type decoder_structure: bool, optional
    :param meta_data_model: The model to handle the meta-data, defaults to None
    :type meta_data_model: PyTorchForecast, optional
    :param use_wandb: Whether Weights and Biases is in use, defaults to False
    :type use_wandb: bool, optional
    :param meta_model: Whether the model leverages meta-data, defaults to None
    :type meta_model: bool, optional
    :param multi_targets: Whether the model, defaults to 1
    :type multi_targets: int, optional
    :param val_or_test: Whether validation or test loss is computed, defaults to "validation_loss"
    :type val_or_test: str, optional
    :param probabilistic: Whether the model is probablistic, defaults to False
    :type probabilistic: bool, optional
    :return: The loss of the first metric in the list.
    :rtype: float
    """
    print('Computing validation loss')
    unscaled_crit = dict.fromkeys(criterion, 0)
    scaled_crit = dict.fromkeys(criterion, 0)
    model.eval()
    output_std = None
    multi_targs1 = multi_targets  # remember the original value; the Informer branch overrides it
    scaler = None
    if validation_loader.dataset.no_scale:
        scaler = validation_loader.dataset
    with torch.no_grad():
        i = 0
        loss_unscaled_full = 0.0
        for src, targ in validation_loader:
            src = src if isinstance(src, list) else src.to(device)
            targ = targ if isinstance(targ, list) else targ.to(device)
            i += 1
            if decoder_structure:
                if type(model).__name__ == "SimpleTransformer":
                    targ_clone = targ.detach().clone()
                    output = greedy_decode(model,
                                           src,
                                           targ.shape[1],
                                           targ_clone,
                                           device=device)[:, :, 0]
                elif type(model).__name__ == "Informer":
                    multi_targets = multi_targs1
                    filled_targ = targ[1].clone()
                    pred_len = model.pred_len
                    filled_targ[:, -pred_len:, :] = torch.zeros_like(
                        filled_targ[:, -pred_len:, :]).float().to(device)
                    output = model(src[0].to(device), src[1].to(device),
                                   filled_targ.to(device), targ[0].to(device))
                    labels = targ[1][:, -pred_len:, 0:multi_targets]
                    src = src[0]
                    multi_targets = False  # skip the generic label slicing below; labels were set above
                else:
                    output = simple_decode(model=model,
                                           src=src,
                                           max_seq_len=targ.shape[1],
                                           real_target=targ,
                                           output_len=sequence_size,
                                           multi_targets=multi_targets,
                                           probabilistic=probabilistic,
                                           scaler=scaler)
                    if probabilistic:
                        output, output_std = output[0], output[1]
                        output, output_std = output[:, :, 0], output_std[0]
                        output_dist = torch.distributions.Normal(
                            output, output_std)
            else:
                if probabilistic:
                    output_dist = model(src.float())
                    output = output_dist.mean.detach().numpy()
                    output_std = output_dist.stddev.detach().numpy()
                else:
                    output = model(src.float())
            if multi_targets == 1:
                labels = targ[:, :, 0]
            elif multi_targets > 1:
                labels = targ[:, :, 0:multi_targets]
            validation_dataset = validation_loader.dataset
            for crit in criterion:
                if validation_dataset.scale:
                    # Should this also do loss.item() stuff?
                    if len(src.shape) == 2:
                        src = src.unsqueeze(0)
                    src1 = src[:, :, 0:multi_targets]
                    loss_unscaled_full = compute_loss(labels,
                                                      output,
                                                      src1,
                                                      crit,
                                                      validation_dataset,
                                                      probabilistic,
                                                      output_std,
                                                      m=multi_targets)
                    unscaled_crit[crit] += loss_unscaled_full.item() * len(
                        labels.float())
                loss = compute_loss(labels,
                                    output,
                                    src,
                                    crit,
                                    False,
                                    probabilistic,
                                    output_std,
                                    m=multi_targets)
                scaled_crit[crit] += loss.item() * len(labels.float())
    if use_wandb:
        if loss_unscaled_full:
            scaled = {
                k.__class__.__name__: v / (len(validation_loader.dataset) - 1)
                for k, v in scaled_crit.items()
            }
            newD = {
                k.__class__.__name__: v / (len(validation_loader.dataset) - 1)
                for k, v in unscaled_crit.items()
            }
            wandb.log({
                'epoch': epoch,
                val_or_test: scaled,
                "unscaled_" + val_or_test: newD
            })
        else:
            scaled = {
                k.__class__.__name__: v / (len(validation_loader.dataset) - 1)
                for k, v in scaled_crit.items()
            }
            wandb.log({'epoch': epoch, val_or_test: scaled})
    model.train()
    return list(scaled_crit.values())[0]
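
A short, self-contained sketch of the Informer decoder input built in the branch above (dummy tensors; the pred_len value is assumed): the last pred_len steps of the decoder input are zero-filled so the model cannot read the future targets it is asked to forecast.

import torch

pred_len = 24                                      # assumed prediction horizon
filled_targ = torch.randn(2, 96, 7)                # (batch, label_len + pred_len, features)
# Zero out the prediction window, exactly as the Informer branch does
filled_targ[:, -pred_len:, :] = torch.zeros_like(filled_targ[:, -pred_len:, :])
print(filled_targ[:, -pred_len:, :].abs().sum())   # tensor(0.)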
Example 4
def compute_validation(validation_loader: DataLoader,
                       model,
                       epoch: int,
                       sequence_size: int,
                       criterion: Type[torch.nn.modules.loss._Loss],
                       device: torch.device,
                       decoder_structure=False,
                       use_wandb: bool = False,
                       val_or_test="validation_loss") -> float:
    """
    Function to compute the validation or test loss
    """
    model.eval()
    loop_loss = 0.0
    with torch.no_grad():
        i = 0
        loss_unscaled_full = 0.0
        for src, targ in validation_loader:
            src = src.to(device)
            targ = targ.to(device)
            i += 1
            if decoder_structure:
                if type(model).__name__ == "SimpleTransformer":
                    targ_clone = targ.detach().clone()
                    output = greedy_decode(model,
                                           src,
                                           targ.shape[1],
                                           targ_clone,
                                           device=device)[:, :, 0]
                else:
                    output = simple_decode(model, src, targ.shape[1], targ,
                                           1)[:, :, 0]
            else:
                output = model(src.float())
            labels = targ[:, :, 0]
            validation_dataset = validation_loader.dataset
            if validation_dataset.scale:
                # unscaled_src = validation_dataset.scale.inverse_transform(src.cpu())
                unscaled_out = validation_dataset.inverse_scale(output.cpu())
                unscaled_labels = validation_dataset.inverse_scale(
                    labels.cpu())
                loss_unscaled = criterion(unscaled_out,
                                          unscaled_labels.float())
                loss_unscaled_full += len(
                    labels.float()) * loss_unscaled.item()
                if i % 10 == 0 and use_wandb:
                    import wandb
                    wandb.log({
                        "trg": unscaled_labels,
                        "model_pred": unscaled_out
                    })
            loss = criterion(output, labels.float())
            loop_loss += len(labels.float()) * loss.item()
    if use_wandb:
        import wandb
        if loss_unscaled_full:
            tot_unscaled_loss = loss_unscaled_full / (
                len(validation_loader.dataset) - 1)
            wandb.log({
                'epoch':
                epoch,
                val_or_test:
                loop_loss / (len(validation_loader.dataset) - 1),
                "unscaled_" + val_or_test:
                tot_unscaled_loss
            })
        else:
            wandb.log({
                'epoch':
                epoch,
                val_or_test:
                loop_loss / (len(validation_loader.dataset) - 1)
            })
    model.train()
    return loop_loss / (len(validation_loader.dataset) - 1)
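
For clarity, a minimal sketch (dummy data only) of the batch-size-weighted averaging this variant uses: each batch loss is weighted by its number of samples before summing, so a smaller final batch does not skew the epoch-level mean. The function above divides by len(dataset) - 1 rather than a running sample count.

import torch

criterion = torch.nn.MSELoss()
batches = [(torch.randn(4, 10), torch.randn(4, 10)) for _ in range(3)]  # dummy (pred, label) pairs

loop_loss, n_samples = 0.0, 0
for preds, labels in batches:
    loop_loss += len(labels) * criterion(preds, labels).item()  # weight by batch size
    n_samples += len(labels)
print(loop_loss / n_samples)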
Example 5
def compute_validation(validation_loader: DataLoader,
                       model,
                       epoch: int,
                       sequence_size: int,
                       criterion: Type[torch.nn.modules.loss._Loss],
                       device: torch.device,
                       decoder_structure=False,
                       meta_data_model=None,
                       use_wandb: bool = False,
                       meta_model=None,
                       val_or_test="validation_loss",
                       probabilistic=False) -> float:
    """
    Function to compute the validation or the test loss
    """
    print('compute_validation')
    model.eval()
    loop_loss = 0.0
    output_std = None
    with torch.no_grad():
        i = 0
        loss_unscaled_full = 0.0
        for src, targ in validation_loader:
            src = src.to(device)
            targ = targ.to(device)
            i += 1
            if decoder_structure:
                if type(model).__name__ == "SimpleTransformer":
                    targ_clone = targ.detach().clone()
                    output = greedy_decode(model,
                                           src,
                                           targ.shape[1],
                                           targ_clone,
                                           device=device)[:, :, 0]
                else:
                    if probabilistic:
                        output, output_std = simple_decode(
                            model,
                            src,
                            targ.shape[1],
                            targ,
                            1,
                            probabilistic=probabilistic)
                        output, output_std = output[:, :, 0], output_std[0]
                        output_dist = torch.distributions.Normal(
                            output, output_std)
                    else:
                        output = simple_decode(
                            model=model,
                            src=src,
                            max_seq_len=targ.shape[1],
                            real_target=targ,
                            output_len=1,
                            probabilistic=probabilistic)[:, :, 0]
            else:
                if probabilistic:
                    output_dist = model(src.float())
                    output = output_dist.mean.detach().numpy()
                    output_std = output_dist.stddev.detach().numpy()
                else:
                    output = model(src.float())
            labels = targ[:, :, 0]
            validation_dataset = validation_loader.dataset
            if validation_dataset.scale:
                loss_unscaled_full += compute_loss(labels, output, src,
                                                   criterion,
                                                   validation_dataset,
                                                   probabilistic, output_std)
            loss = compute_loss(labels, output, src, criterion, False,
                                probabilistic, output_std)
            loop_loss += len(labels.float()) * loss.item()
    if use_wandb:
        if loss_unscaled_full:
            tot_unscaled_loss = loss_unscaled_full / (
                len(validation_loader.dataset) - 1)
            wandb.log({
                'epoch':
                epoch,
                val_or_test:
                loop_loss / (len(validation_loader.dataset) - 1),
                "unscaled_" + val_or_test:
                tot_unscaled_loss
            })
        else:
            wandb.log({
                'epoch':
                epoch,
                val_or_test:
                loop_loss / (len(validation_loader.dataset) - 1)
            })
    model.train()
    return loop_loss / (len(validation_loader.dataset) - 1)
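
A self-contained sketch of the probabilistic branch above (dummy tensors): the decoded mean and standard deviation parameterize a torch.distributions.Normal, and the loss used downstream is the negative log-likelihood of the labels under that distribution.

import torch

mean = torch.randn(2, 5)
std = torch.rand(2, 5) + 0.1                        # keep strictly positive
labels = torch.randn(2, 5)

output_dist = torch.distributions.Normal(mean, std)
nll = -output_dist.log_prob(labels.float()).sum()   # negative log-likelihood
print(nll.item())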
Example 6
def compute_validation(validation_loader: DataLoader,
                       model,
                       epoch: int,
                       sequence_size: int,
                       criterion: Type[torch.nn.modules.loss._Loss],
                       device: torch.device,
                       decoder_structure=False,
                       meta_data_model=None,
                       use_wandb: bool = False,
                       meta_model=None,
                       val_or_test="validation_loss",
                       probabilistic=False) -> float:
    """
    Function to compute the validation or the test loss
    """
    print('compute_validation')
    model.eval()
    loop_loss = 0.0
    with torch.no_grad():
        i = 0
        loss_unscaled_full = 0.0
        for src, targ in validation_loader:
            src = src.to(device)
            targ = targ.to(device)
            i += 1
            if decoder_structure:
                if type(model).__name__ == "SimpleTransformer":
                    targ_clone = targ.detach().clone()
                    output = greedy_decode(model,
                                           src,
                                           targ.shape[1],
                                           targ_clone,
                                           device=device)[:, :, 0]
                else:
                    if probabilistic:
                        output, output_std = simple_decode(model,
                                                           src,
                                                           targ.shape[1],
                                                           targ,
                                                           1,
                                                           probabilistic=probabilistic)
                        output, output_std = output[:, :, 0], output_std[0]
                        output_dist = torch.distributions.Normal(output, output_std)
                    else:
                        output = simple_decode(model=model,
                                               src=src,
                                               max_seq_len=targ.shape[1],
                                               real_target=targ,
                                               output_len=1,
                                               probabilistic=probabilistic)[:, :, 0]
            else:
                if probabilistic:
                    output_dist = model(src.float())
                    output = output_dist.mean.detach().numpy()
                    output_std = output_dist.stddev.detach().numpy()
                else:
                    output = model(src.float())
            labels = targ[:, :, 0]
            validation_dataset = validation_loader.dataset
            if validation_dataset.scale:
                unscaled_labels = validation_dataset.inverse_scale(labels)
                if probabilistic:
                    unscaled_out = validation_dataset.inverse_scale(output)
                    try:
                        output_std = numpy_to_tvar(output_std)
                    except Exception:
                        pass
                    unscaled_dist = torch.distributions.Normal(unscaled_out, output_std)
                    loss_unscaled = -unscaled_dist.log_prob(unscaled_labels.float()).sum()  # FIX THIS
                    loss_unscaled_full += len(labels.float()) * loss_unscaled.numpy().item()
                else:
                    # unscaled_src = validation_dataset.scale.inverse_transform(src.cpu())
                    unscaled_out = validation_dataset.inverse_scale(output.cpu())
                    unscaled_labels = validation_dataset.inverse_scale(labels.cpu())
                    loss_unscaled = criterion(unscaled_out, unscaled_labels.float())
                    loss_unscaled_full += len(labels.float()) * loss_unscaled.item()
                if i % 10 == 0 and use_wandb:
                    wandb.log({"trg": unscaled_labels, "model_pred": unscaled_out})
            if probabilistic:
                loss = -output_dist.log_prob(labels.float()).sum()  # FIX THIS
                loss = loss.numpy()
            elif isinstance(criterion, GaussianLoss):
                g_loss = GaussianLoss(output[0], output[1])
                loss = g_loss(labels)
            else:
                loss = criterion(output, labels.float())
            loop_loss += len(labels.float()) * loss.item()
    if use_wandb:
        if loss_unscaled_full:
            tot_unscaled_loss = loss_unscaled_full / (len(validation_loader.dataset) - 1)
            wandb.log({'epoch': epoch,
                       val_or_test: loop_loss / (len(validation_loader.dataset) - 1),
                       "unscaled_" + val_or_test: tot_unscaled_loss})
        else:
            wandb.log({'epoch': epoch, val_or_test: loop_loss /
                       (len(validation_loader.dataset) - 1)})
    model.train()
    return loop_loss / (len(validation_loader.dataset) - 1)
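
A hedged sketch of the unscaling step above, with scikit-learn's StandardScaler standing in for the dataset's inverse_scale method (the real method belongs to the validation dataset, not sklearn): the unscaled loss is computed in the original units of the target.

import torch
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
raw = torch.randn(100, 1).numpy() * 7.0 + 3.0       # dummy series in original units
scaled = scaler.fit_transform(raw)

preds_scaled = torch.tensor(scaled[:20], dtype=torch.float32)   # pretend model output
unscaled_out = torch.tensor(scaler.inverse_transform(preds_scaled.numpy()),
                            dtype=torch.float32)
labels_raw = torch.tensor(raw[:20], dtype=torch.float32)
print(torch.nn.MSELoss()(unscaled_out, labels_raw).item())      # ~0: preds equal labels here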
Example 7
def compute_validation(validation_loader: DataLoader,
                       model,
                       epoch: int,
                       sequence_size: int,
                       criterion: Type[torch.nn.modules.loss._Loss],
                       device: torch.device,
                       decoder_structure=False,
                       meta_data_model=None,
                       use_wandb: bool = False,
                       meta_model=None,
                       multi_targets=1,
                       val_or_test="validation_loss",
                       probabilistic=False) -> float:
    """Function to compute the validation loss metrics

    :param validation_loader: The data-loader of either validation or test-data
    :type validation_loader: DataLoader
    :param model: model
    :type model: [type]
    :param epoch: [description]
    :type epoch: int
    :param sequence_size: [description]
    :type sequence_size: int
    :param criterion: [description]
    :type criterion: Type[torch.nn.modules.loss._Loss]
    :param device: [description]
    :type device: torch.device
    :param decoder_structure: [description], defaults to False
    :type decoder_structure: bool, optional
    :param meta_data_model: [description], defaults to None
    :type meta_data_model: [type], optional
    :param use_wandb: [description], defaults to False
    :type use_wandb: bool, optional
    :param meta_model: [description], defaults to None
    :type meta_model: [type], optional
    :param multi_targets: [description], defaults to 1
    :type multi_targets: int, optional
    :param val_or_test: [description], defaults to "validation_loss"
    :type val_or_test: str, optional
    :param probabilistic: [description], defaults to False
    :type probabilistic: bool, optional
    :return: [description]
    :rtype: float
    """
    print('compute_validation')
    unscaled_crit = dict.fromkeys(criterion, 0)
    scaled_crit = dict.fromkeys(criterion, 0)
    model.eval()
    output_std = None
    scaler = None
    if validation_loader.dataset.no_scale:
        scaler = validation_loader.dataset
    with torch.no_grad():
        i = 0
        loss_unscaled_full = 0.0
        for src, targ in validation_loader:
            src = src.to(device)
            targ = targ.to(device)
            i += 1
            if decoder_structure:
                if type(model).__name__ == "SimpleTransformer":
                    targ_clone = targ.detach().clone()
                    output = greedy_decode(model,
                                           src,
                                           targ.shape[1],
                                           targ_clone,
                                           device=device)[:, :, 0]
                else:
                    if probabilistic:
                        output, output_std = simple_decode(model,
                                                           src,
                                                           targ.shape[1],
                                                           targ,
                                                           1,
                                                           multi_targets=multi_targets,
                                                           probabilistic=probabilistic,
                                                           scaler=scaler)
                        output, output_std = output[:, :, 0], output_std[0]
                        output_dist = torch.distributions.Normal(output, output_std)
                    else:
                        output = simple_decode(model=model,
                                               src=src,
                                               max_seq_len=targ.shape[1],
                                               real_target=targ,
                                               output_len=sequence_size,
                                               multi_targets=multi_targets,
                                               probabilistic=probabilistic,
                                               scaler=scaler)[:, :, 0:multi_targets]
            else:
                if probabilistic:
                    output_dist = model(src.float())
                    output = output_dist.mean.detach().numpy()
                    output_std = output_dist.stddev.detach().numpy()
                else:
                    output = model(src.float())
            if multi_targets == 1:
                labels = targ[:, :, 0]
            elif multi_targets > 1:
                labels = targ[:, :, 0:multi_targets]
            validation_dataset = validation_loader.dataset
            for crit in criterion:
                if validation_dataset.scale:
                    # Should this also do loss.item() stuff?
                    if len(src.shape) == 2:
                        src = src.unsqueeze(0)
                    src1 = src[:, :, 0:multi_targets]
                    loss_unscaled_full = compute_loss(labels, output, src1, crit, validation_dataset,
                                                      probabilistic, output_std, m=multi_targets)
                    unscaled_crit[crit] += loss_unscaled_full.item() * len(labels.float())
                loss = compute_loss(labels, output, src, crit, False, probabilistic, output_std, m=multi_targets)
                scaled_crit[crit] += loss.item() * len(labels.float())
    if use_wandb:
        if loss_unscaled_full:
            scaled = {k.__class__.__name__: v / (len(validation_loader.dataset) - 1) for k, v in scaled_crit.items()}
            newD = {k.__class__.__name__: v / (len(validation_loader.dataset) - 1) for k, v in unscaled_crit.items()}
            wandb.log({'epoch': epoch,
                       val_or_test: scaled,
                       "unscaled_" + val_or_test: newD})
        else:
            scaled = {k.__class__.__name__: v / (len(validation_loader.dataset) - 1) for k, v in scaled_crit.items()}
            wandb.log({'epoch': epoch, val_or_test: scaled})
    model.train()
    return list(scaled_crit.values())[0]
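
Finally, a minimal self-contained sketch (dummy tensors, no DataLoader) of the per-criterion bookkeeping used in this variant: one running sum per loss module, keyed by the module instance and later reported under its class name.

import torch

criterion = [torch.nn.MSELoss(), torch.nn.L1Loss()]
scaled_crit = dict.fromkeys(criterion, 0)            # one accumulator per loss

preds, labels = torch.randn(8, 20), torch.randn(8, 20)
for crit in criterion:
    scaled_crit[crit] += crit(preds, labels).item() * len(labels)  # batch-size weighted

averaged = {k.__class__.__name__: v / len(labels) for k, v in scaled_crit.items()}
print(averaged)   # e.g. {'MSELoss': ..., 'L1Loss': ...}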