    train_elbo.append(-total_epoch_loss_train)
    print("[epoch %d] average training loss: %.4f" % (epoch, total_epoch_loss_train))

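    # Evaluate on the held-out test set every TEST_FREQUENCY epochs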
    if epoch % TEST_FREQUENCY == 0:
        vae.eval()
        total_epoch_loss_test = evaluate(svi, test_loader)
        vae.train()
        test_elbo.append(-total_epoch_loss_test)
        print("[epoch %d] average test loss: %.4f" % (epoch, total_epoch_loss_test))

        # Save Pyro parameters, optimizer state and network weights whenever the test loss improves
        if total_epoch_loss_test < best:
          print('SAVING EPOCH', epoch)
          best = total_epoch_loss_test
          pyro.get_param_store().save('drive/My Drive/pyro_weights.save')
          optimizer.save('drive/My Drive/optimizer_state.save')
          checkpoint = {'model_state_dict': vae.state_dict()}
          torch.save(checkpoint, 'drive/My Drive/torch_weights.save')

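        # Plot the first test sample (RGB channels and the fourth channel) next to its VAE reconstruction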
        i = 0
        fig = plt.figure()
        fig.add_subplot(2, 2, 1)
        plt.imshow(test_loader.dataset[i][:3].permute(1, 2, 0)*test_loader.dataset.dataset.std[:3] + test_loader.dataset.dataset.mean[:3])
        fig.add_subplot(2, 2, 2)
        plt.imshow(test_loader.dataset[i][3]*test_loader.dataset.dataset.std[3] + test_loader.dataset.dataset.mean[3])

        test_input = test_loader.dataset[i].unsqueeze(0)#.cuda()
        reconstructed = vae.reconstruct(test_input).cpu().detach()[0]
        fig.add_subplot(2, 2, 3)
        plt.imshow(reconstructed[:3].permute(1, 2, 0)*test_loader.dataset.dataset.std[:3] + test_loader.dataset.dataset.mean[:3])
        fig.add_subplot(2, 2, 4)
        plt.imshow(reconstructed[3]*test_loader.dataset.dataset.std[3] + test_loader.dataset.dataset.mean[3])
        plt.show()
class BayesianNeuralNetworkRegression:
    """Bayesian Neural Network that uses a Pyro model to predict multiple targets

    Uses Pyro's ELBO loss internally
    Args:
        batch_size (int, optional): If set, the training set is split into batches of the given size; otherwise the full set is used as a single batch. Default: None
        shuffle (bool, optional): Set to True to have the data reshuffled at every epoch. Default: False
        learning_rate (float, optional): Learning rate for the optimizer. Default: 1e-3
        use_gpu (bool, optional): Flag that allows usage of CUDA cores for calculations. Default: False
        patience (int, optional): Stop training after this many consecutive increases of the validation loss. Default: None
        training_limit (int, optional): Training is terminated after this number of epochs, regardless of early stopping. Default: 100
        verbosity (int, optional): 0 to only print errors, 1 (default) to print status information. Default: 1
        print_after_epochs (int, optional): Specifies after how many epochs the training and validation loss are printed to the command line. Default: 500
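
    Example (illustrative sketch, assuming 2-D numpy arrays X_train, y_train, X_test and y_test):
        model = BayesianNeuralNetworkRegression(batch_size=32, patience=5)
        model.fit(X_train, y_train)
        stds, means = model.predict(X_test)
        error = model.score(X_test, y_test)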
    """
    def __init__(self,
                 batch_size=None,
                 shuffle=False,
                 learning_rate=1e-3,
                 use_gpu=False,
                 patience=None,
                 training_limit=100,
                 verbosity=1,
                 print_after_epochs=500):
        self.patience = patience
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.learning_rate = learning_rate
        self.training_limit = training_limit
        self.print_after_epochs = print_after_epochs
        self.verbosity = verbosity
        self.Device = 'cpu'
        if use_gpu is True and torch.cuda.is_available():
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
            self.Device = "cuda:0"

        if training_limit is None and patience is None:
            raise ValueError('Either training_limit or patience must be set')

    # FIXME: clarify training behaviour when both patience and training_limit are set

    def fit(self, X_train, y_train):
        """Fits the model to the training data set

        Args:
            X_train (np.ndarray): Set of descriptive variables
            y_train (np.ndarray): Set of target variables

        Returns:
            BayesianNeuralNetworkRegression: the fitted regressor
        """
        X_train, X_validate_t, y_train, y_validate_t = train_test_split(
            X_train, y_train, test_size=0.1)
        X_train_t = torch.tensor(X_train, dtype=torch.float).to(self.Device)
        y_train_t = torch.tensor(y_train, dtype=torch.float).to(self.Device)
        X_validate_t = torch.tensor(X_validate_t,
                                    dtype=torch.float).to(self.Device)
        y_validate_t = torch.tensor(y_validate_t,
                                    dtype=torch.float).to(self.Device)

        n_targets = len(y_train_t[0])
        n_features = len(X_train_t[0])
        self.net = NN(n_features, n_targets)
        self.net.to(self.Device)
        self.guide = AutoDiagonalNormal(self.model)
        self.optim = Adam({"lr": self.learning_rate})
        self.svi = SVI(self.model, self.guide, self.optim, loss=Trace_ELBO())

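        # Fall back to full-batch training when no batch_size was given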
        batch_size = len(
            X_train_t) if self.batch_size is None else self.batch_size
        train_dataloader = DataLoader(TensorDataset(X_train_t, y_train_t),
                                      batch_size=batch_size,
                                      shuffle=self.shuffle)
        pyro.clear_param_store()
        losses = []
        if self.patience is not None:
            stopper = EarlyStopping(self.patience)
        stop = False
        epochs = 0
        while not stop:
            # calculate the loss and take a gradient step
            for batch in train_dataloader:
                batch_X = batch[0]
                batch_y = batch[1]
                loss_batch = self.svi.step(batch_X, batch_y)
                losses.append(loss_batch)

            validation_loss = self.svi.evaluate_loss(X_validate_t,
                                                     y_validate_t)
            train_loss = self.svi.evaluate_loss(X_train_t, y_train_t)
            if self.patience is not None:
                stop = stopper.stop(validation_loss, self.net)
            # TODO: when early stopping triggers, restore the best model, guide
            # and optimizer state here (e.g. self.net = stopper.best_model)
            if epochs % self.print_after_epochs == 0:
                printMessage(
                    'Epoch: {}\nValidation Loss: {} \nTrain Loss: {}'.format(
                        epochs, validation_loss, train_loss), self.verbosity)

            epochs += 1

            if self.training_limit is not None and self.training_limit <= epochs:
                stop = True

        final_train_loss = self.svi.evaluate_loss(X_train_t, y_train_t)
        final_validation_loss = self.svi.evaluate_loss(X_validate_t,
                                                       y_validate_t)
        printMessage(
            "Final Epochs: {} \nFinal Train Loss: {}\nFinal Validation Loss: {}"
            .format(epochs, final_train_loss,
                    final_validation_loss), self.verbosity)
        return self

    def predict(self, X_test, num_samples=100):
        """Predicts the target variables for the given test set
        Args:
            X_test (np.ndarray): Test set with descriptive variables
            num_samples (int, optional): Number of posterior samples drawn for the prediction. Default: 100
        Returns:
            tuple of np.ndarray: standard deviations and means of the posterior predictive samples
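        Example (illustrative sketch, assuming a fitted model and a 2-D numpy array X_test):
            stds, means = model.predict(X_test, num_samples=200)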
        """
        from pyro.infer import Predictive
        x_data_test = torch.tensor(X_test, dtype=torch.float).to(self.Device)

        predictive = Predictive(self.net,
                                guide=self.guide,
                                num_samples=num_samples,
                                return_sites=("obs", "_RETURN"))

        samples = predictive(x_data_test)
        pred_summary = self.__summary(samples)
        stds = pred_summary['_RETURN']['std']
        means = pred_summary['_RETURN']['mean']

        return stds.cpu().detach().numpy(), means.cpu().detach().numpy()

    def __summary(self, samples):
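        # For each returned site, compute the mean, standard deviation and the
        # 5% / 95% quantiles across the posterior predictive samples.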
        site_stats = {}
        for k, v in samples.items():
            site_stats[k] = {
                "mean": torch.mean(v, 0),
                "std": torch.std(v, 0),
                "5%": v.kthvalue(int(len(v) * 0.05), dim=0)[0],
                "95%": v.kthvalue(int(len(v) * 0.95), dim=0)[0],
            }
        return site_stats

    def save(self, path):
        """Save model and store it at given path

        Args:
            path (string): Path prefix; four files are written with the suffixes '_model', '_guide', '_opt' and '_params'
        """
        model_path = path + '_model'
        opt_path = path + '_opt'
        guide_path = path + '_guide'
        torch.save(self.net, model_path)
        torch.save(self.guide, guide_path)

        self.optim.save(opt_path)
        ps = pyro.get_param_store()
        ps.save(path + '_params')

    def load(self, path):
        """Load model from path

        Args:
            path (string): Path prefix of the saved model files (see save)
        """
        model_path = path + '_model'
        opt_path = path + '_opt'
        guide_path = path + '_guide'
        self.net = torch.load(model_path)

        pyro.get_param_store().load(path + '_params')

        self.optim = Adam({"lr": self.learning_rate})
        self.optim.load(opt_path)
        self.guide = torch.load(guide_path)
        self.svi = SVI(self.model, self.guide, self.optim, loss=Trace_ELBO())

    def model(self, x_data, y_data):
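        # Generative model: standard normal priors on all network weights and
        # biases, a Uniform(0, 10) prior on the observation noise scale, and a
        # Gaussian likelihood conditioned on y_data.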
        fc1w_prior = Normal(
            loc=torch.zeros_like(self.net.fc1.weight).to(self.Device),
            scale=torch.ones_like(self.net.fc1.weight).to(self.Device))
        fc1b_prior = Normal(
            loc=torch.zeros_like(self.net.fc1.bias).to(self.Device),
            scale=torch.ones_like(self.net.fc1.bias).to(self.Device))

        fc2w_prior = Normal(
            loc=torch.zeros_like(self.net.fc2.weight).to(self.Device),
            scale=torch.ones_like(self.net.fc2.weight).to(self.Device))
        fc2b_prior = Normal(
            loc=torch.zeros_like(self.net.fc2.bias).to(self.Device),
            scale=torch.ones_like(self.net.fc2.bias).to(self.Device))

        outw_prior = Normal(
            loc=torch.zeros_like(self.net.out.weight).to(self.Device),
            scale=torch.ones_like(self.net.out.weight).to(self.Device))
        outb_prior = Normal(
            loc=torch.zeros_like(self.net.out.bias).to(self.Device),
            scale=torch.ones_like(self.net.out.bias).to(self.Device))
        priors = {
            'fc1.weight': fc1w_prior,
            'fc1.bias': fc1b_prior,
            'fc2.weight': fc2w_prior,
            'fc2.bias': fc2b_prior,
            'out.weight': outw_prior,
            'out.bias': outb_prior
        }
        # lift module parameters to random variables sampled from the priors
        lifted_module = pyro.random_module("module", self.net, priors)
        # sample a regressor (which also samples w and b)
        lifted_reg_model = lifted_module()
        scale = pyro.sample("sigma", Uniform(0., 10.))
        # with pyro.plate("map", len(x_data)):
        # run the nn forward on data
        prediction_mean = lifted_reg_model(x_data).squeeze(-1)
        # condition on the observed data
        with pyro.iarange("observed data", use_cuda=True):
            pyro.sample("obs", Normal(prediction_mean, scale), obs=y_data)
        return prediction_mean

    def score(self, X_test, y_test):
        """Returns Average Relative Root Mean Squared Error for given test data and targets

        Args:
            X_test (np.ndarray): Test samples
            y_test (np.ndarray): True targets
        """
        # predict returns (stds, means); the score only needs the means
        y_pred = self.predict(X_test)[1]
        return average_relative_root_mean_squared_error(y_pred, y_test)

    # The following methods are required for scikit-learn grid search compatibility

    def _get_param_names(cls):
        """Get parameter names for the estimator"""
        # fetch the constructor or the original constructor before
        # deprecation wrapping if any
        init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
        if init is object.__init__:
            # No explicit constructor to introspect
            return []

        # introspect the constructor arguments to find the model parameters
        # to represent
        init_signature = inspect.signature(init)
        # Consider the constructor parameters excluding 'self'
        parameters = [
            p for p in init_signature.parameters.values()
            if p.name != 'self' and p.kind != p.VAR_KEYWORD
        ]
        for p in parameters:
            if p.kind == p.VAR_POSITIONAL:
                raise RuntimeError("scikit-learn estimators should always "
                                   "specify their parameters in the signature"
                                   " of their __init__ (no varargs)."
                                   " %s with constructor %s doesn't "
                                   " follow this convention." %
                                   (cls, init_signature))
        # Extract and sort argument names excluding 'self'
        return sorted([p.name for p in parameters])

    def get_params(self, deep=True):
        """Get parameters for this estimator.
        Parameters
        ----------
        deep : boolean, optional
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.
        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values.
        """
        out = dict()
        for key in self._get_param_names():
            value = getattr(self, key, None)
            if deep and hasattr(value, 'get_params'):
                deep_items = value.get_params().items()
                out.update((key + '__' + k, val) for k, val in deep_items)
            out[key] = value
        return out

    def set_params(self, **params):
        """Set the parameters of this estimator.
        The method works on simple estimators as well as on nested objects
        (such as pipelines). The latter have parameters of the form
        ``<component>__<parameter>`` so that it's possible to update each
        component of a nested object.
        Returns
        -------
        self
        """
        if not params:
            # Simple optimization to gain speed (inspect is slow)
            return self
        valid_params = self.get_params(deep=True)

        nested_params = defaultdict(dict)  # grouped by prefix
        for key, value in params.items():
            key, delim, sub_key = key.partition('__')
            if key not in valid_params:
                raise ValueError('Invalid parameter %s for estimator %s. '
                                 'Check the list of available parameters '
                                 'with `estimator.get_params().keys()`.' %
                                 (key, self))

            if delim:
                nested_params[key][sub_key] = value
            else:
                setattr(self, key, value)
                valid_params[key] = value

        for key, sub_params in nested_params.items():
            valid_params[key].set_params(**sub_params)

        return self