Beispiel #1
0
    def test_qrnn(self, backend):
        """
        Train a QRNN on plain numpy arrays and exercise its post-processing API.

        After one training epoch the test runs prediction, the CDF, the PDF,
        the posterior mean and both posterior samplers, checking the shape
        (and, for the CDF, the end points) of what they return.
        """
        set_backend(backend)
        n_inputs = self.x_train.shape[1]
        quantile_levels = np.linspace(0.05, 0.95, 10)
        model = QRNN(n_inputs, quantile_levels)
        model.train((self.x_train, self.y_train), maximum_epochs=1)

        model.predict(self.x_train)

        two_rows = self.x_train[:2, :]
        four_rows = self.x_train[:4, :]

        # The first CDF value must be exactly 0 and the last exactly 1.
        _, cdf_values = model.cdf(two_rows)
        assert cdf_values[0] == 0.0
        assert cdf_values[-1] == 1.0

        # The PDF grid and the density values must line up element-wise.
        grid, density = model.pdf(two_rows)
        assert grid.shape == density.shape

        # One posterior mean per input row, i.e. a 1-D result.
        mean = model.posterior_mean(two_rows)
        assert len(mean.shape) == 1

        # Both samplers return one row per input and one column per sample.
        samples = model.sample_posterior(four_rows, n=2)
        assert samples.shape == (4, 2)

        samples = model.sample_posterior_gaussian_fit(four_rows, n=2)
        assert samples.shape == (4, 2)
Beispiel #2
0
 def test_qrnn_datasets(self, backend):
     """
     Train a QRNN from a backend dataset object instead of raw numpy arrays.
     """
     set_backend(backend)
     # Keep the backend name and the backend module in separate variables.
     backend_module = get_backend(backend)
     training_data = backend_module.BatchedDataset(
         (self.x_train, self.y_train), 256  # presumably the batch size — confirm
     )
     model = QRNN(self.x_train.shape[1], np.linspace(0.05, 0.95, 10))
     model.train(training_data, maximum_epochs=1)
Beispiel #3
0
    def test_save_qrnn(self, backend):
        """
        Test saving and loading of QRNNs.

        A freshly constructed QRNN is saved to a temporary file and loaded
        back. The predictions of the original model and of the loaded model
        must agree.
        """
        set_backend(backend)
        qrnn = QRNN(self.x_train.shape[1], np.linspace(0.05, 0.95, 10))
        # The context manager guarantees the temporary file is closed and removed.
        with tempfile.NamedTemporaryFile() as f:
            qrnn.save(f.name)
            qrnn_loaded = QRNN.load(f.name)

        x_pred = qrnn.predict(self.x_train)
        # Predict with the *loaded* model. Comparing qrnn with itself would
        # make the round-trip check vacuous.
        x_pred_loaded = qrnn_loaded.predict(self.x_train)

        # Non-numpy backends return objects with a `.detach()` method
        # (presumably torch tensors — confirm). Detach both predictions before
        # comparing, not just one of them.
        if not isinstance(x_pred, np.ndarray):
            x_pred = x_pred.detach()
        if not isinstance(x_pred_loaded, np.ndarray):
            x_pred_loaded = x_pred_loaded.detach()

        assert np.allclose(x_pred, x_pred_loaded)
Beispiel #4
0
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 12 22:09:27 2020

@author: inderpreet

Plot the uncertainties in randomly chosen cases.
"""
import matplotlib.pyplot as plt
import numpy as np
import netCDF4

from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator)
from typhon.retrieval.qrnn import set_backend, QRNN
# Select the PyTorch backend for typhon's QRNN. This call sits before the
# local-module imports below. NOTE(review): those modules may rely on the
# backend already being selected — confirm before reordering these lines.
set_backend("pytorch")
import stats as S
from ici import iciData
from calibration import calibration
import random
# Use a large default font size for every figure produced by this script.
plt.rcParams.update({'font.size': 26})
from matplotlib import colors as mcolors
# Single lookup table of named colors: matplotlib's base colors merged with
# the CSS4 named colors.
colors = dict(mcolors.BASE_COLORS, **mcolors.CSS4_COLORS)


#%% input parameters
# NOTE(review): depth/width presumably give the QRNN's number of hidden layers
# and units per layer — confirm against where they are used later in the script.
depth     = 4
width     = 128
# Quantile levels: the median plus symmetric pairs 0.16/0.84, 0.03/0.97 and
# 0.002/0.998.
quantiles = np.array([0.002, 0.03, 0.16, 0.5, 0.84, 0.97, 0.998])
batchSize = 128