Example #1
    def _model():
        p = yield Root(tfd.Beta(dtype(1), dtype(1), name="p"))
        gamma_C = yield Root(tfd.Beta(dtype(1), dtype(1), name="gamma_C"))
        gamma_T = yield Root(tfd.Beta(dtype(1), dtype(1), name="gamma_T"))
        eta_C = yield Root(tfd.Dirichlet(np.ones(K, dtype=dtype) / K,
                                         name="eta_C"))
        eta_T = yield Root(tfd.Dirichlet(np.ones(K, dtype=dtype) / K,
                                         name="eta_T"))
        loc = yield Root(tfd.Sample(tfd.Normal(dtype(0), dtype(1)),
                                    sample_shape=K, name="loc"))
        nu = yield Root(tfd.Sample(tfd.Uniform(dtype(10), dtype(50)),
                                   sample_shape=K, name="nu"))
        phi = yield Root(tfd.Sample(tfd.Normal(dtype(m_phi), dtype(s_phi)),
                                    sample_shape=K, name="phi"))
        sigma_sq = yield Root(tfd.Sample(tfd.InverseGamma(dtype(3), dtype(2)),
                                         sample_shape=K,
                                         name="sigma_sq"))
        scale = tf.sqrt(sigma_sq)  # tf.sqrt (not np.sqrt) so this also works on symbolic tensors

        gamma_T_star = compute_gamma_T_star(gamma_C, gamma_T, p)
        eta_T_star = compute_eta_T_star(gamma_C[..., tf.newaxis],
                                        gamma_T[..., tf.newaxis],
                                        eta_C, eta_T,
                                        p[..., tf.newaxis],
                                        gamma_T_star[..., tf.newaxis])

        # likelihood
        y_C = yield mix(nC, eta_C, loc, scale, name="y_C")
        n0C = yield tfd.Binomial(nC, probs=gamma_C, name="n0C")  # probs=, since Binomial's 2nd positional arg is logits
        y_T = yield mix(nT, eta_T_star, loc, scale, name="y_T")
        n0T = yield tfd.Binomial(nT, probs=gamma_T_star, name="n0T")
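    The generator above only defines the joint model; a minimal usage sketch of how such a coroutine is typically wrapped and sampled, assuming Root is tfd.JointDistributionCoroutine.Root and that dtype, K, nC, nT, m_phi, s_phi, mix, compute_gamma_T_star, and compute_eta_T_star are defined in the enclosing scope as in the original source:

    # Hedged usage sketch -- not part of the original example.
    model = tfd.JointDistributionCoroutine(_model)
    prior_draw = model.sample()          # one draw of every yielded variable
    log_p = model.log_prob(prior_draw)   # joint log density of that draw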
Example #2
def create_prior(K,
                 a_p=1,
                 b_p=1,
                 a_gamma=1,
                 b_gamma=1,
                 m_loc=0,
                 g_loc=0.1,
                 m_sigma=3,
                 s_sigma=2,
                 m_nu=0,
                 s_nu=1,
                 m_skew=0,
                 g_skew=0.1,
                 dtype=np.float64):
    return tfd.JointDistributionNamed(
        dict(
            p=tfd.Beta(dtype(a_p), dtype(b_p)),
            gamma_C=tfd.Gamma(dtype(a_gamma), dtype(b_gamma)),
            gamma_T=tfd.Gamma(dtype(a_gamma), dtype(b_gamma)),
            eta_C=tfd.Dirichlet(tf.ones(K, dtype=dtype) / K),
            eta_T=tfd.Dirichlet(tf.ones(K, dtype=dtype) / K),
            nu=tfd.Sample(tfd.LogNormal(dtype(m_nu), dtype(s_nu)), sample_shape=K),
            sigma_sq=tfd.Sample(tfd.InverseGamma(dtype(m_sigma),
                                                 dtype(s_sigma)),
                                sample_shape=K),
            loc=lambda sigma_sq: tfd.Independent(tfd.Normal(
                dtype(m_loc), g_loc * tf.sqrt(sigma_sq)),
                                                 reinterpreted_batch_ndims=1),
            skew=lambda sigma_sq: tfd.Independent(tfd.Normal(
                dtype(m_skew), g_skew * tf.sqrt(sigma_sq)),
                                                  reinterpreted_batch_ndims=1),
        ))
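A minimal usage sketch for create_prior, assuming the usual np/tf/tfd aliases used throughout these examples; the values are illustrative:

# Hedged usage sketch -- not part of the original example.
prior = create_prior(K=3)
draw = prior.sample()          # dict with keys p, gamma_C, gamma_T, eta_C, ...
print(draw["eta_C"].shape)     # (3,)
print(prior.log_prob(draw))    # scalar joint log density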
Example #3
def mix(gamma, eta, loc, scale, neg_inf, n):
    return tfd.Mixture(
        cat=tfd.Categorical(probs=tf.stack([gamma, 1 - gamma], axis=-1)),
        components=[
            tfd.Sample(tfd.Normal(np.float64(neg_inf), 1e-5), sample_shape=n),
            tfd.Sample(tfd.MixtureSameFamily(
                mixture_distribution=tfd.Categorical(probs=eta),
                components_distribution=tfd.Normal(loc=loc, scale=scale)),
                       sample_shape=n)
        ])
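A minimal sketch of what this spike-plus-GMM mixture produces, with hypothetical parameter values; note that tfd.Mixture selects a single component for the whole length-n event, so a draw is either all near neg_inf or all from the Gaussian mixture:

# Hedged usage sketch -- not part of the original example.
K = 3
dist = mix(gamma=np.float64(0.1),
           eta=np.ones(K) / K,
           loc=np.linspace(-2.0, 2.0, K),
           scale=np.ones(K),
           neg_inf=-10.0,
           n=5)
y = dist.sample()        # shape (5,)
print(dist.log_prob(y))  # scalar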
Example #4
def create_model(n_C, n_T, K, neg_inf=-10, dtype=np.float64):
    return tfd.JointDistributionNamed(
        dict(p=tfd.Beta(dtype(1), dtype(1)),
             gamma_C=tfd.Gamma(dtype(3), dtype(3)),
             gamma_T=tfd.Gamma(dtype(3), dtype(3)),
             eta_C=tfd.Dirichlet(tf.ones(K, dtype=dtype) / K),
             eta_T=tfd.Dirichlet(tf.ones(K, dtype=dtype) / K),
             loc=tfd.Sample(tfd.Normal(dtype(0), dtype(1)), sample_shape=K),
             sigma_sq=tfd.Sample(tfd.InverseGamma(dtype(3), dtype(2)),
                                 sample_shape=K),
             y_C=lambda gamma_C, eta_C, loc, sigma_sq: mix(
                 gamma_C, eta_C, loc, tf.sqrt(sigma_sq), dtype(neg_inf), n_C),
             y_T=lambda gamma_C, gamma_T, eta_C, eta_T, p, loc, sigma_sq:
             mix_T(gamma_C, gamma_T, eta_C, eta_T, p, loc, tf.sqrt(sigma_sq),
                   dtype(neg_inf), n_T)))
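A minimal sketch of conditioning this joint model on observed data, e.g. to build a target log density for MCMC; mix_T is assumed to be defined elsewhere in the original source, and y_C_obs / y_T_obs are hypothetical observed arrays of lengths n_C and n_T:

# Hedged usage sketch -- not part of the original example.
model = create_model(n_C=50, n_T=60, K=5)

def target_log_prob(p, gamma_C, gamma_T, eta_C, eta_T, loc, sigma_sq):
    return model.log_prob(dict(
        p=p, gamma_C=gamma_C, gamma_T=gamma_T, eta_C=eta_C, eta_T=eta_T,
        loc=loc, sigma_sq=sigma_sq, y_C=y_C_obs, y_T=y_T_obs))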
Example #5
 def __init__(
     self,
     name: Optional[NameType],
     *,
     transform=None,
     observed=None,
     batch_stack=None,
     event_stack=None,
     conditionally_independent=False,
     reinterpreted_batch_ndims=0,
     **kwargs,
 ):
     self.conditions = self.unpack_conditions(**kwargs)
     self._distribution = self._init_distribution(self.conditions)
     super().__init__(
         self.unpack_distribution, name=name, keep_return=True, keep_auxiliary=False
     )
     if name is None and observed is not None:
         raise ValueError(
             "Observed variables are not allowed for anonymous (with name=None) Distributions"
         )
     self.model_info.update(observed=observed)
     self.transform = self._init_transform(transform)
     self.batch_stack = batch_stack
     self.event_stack = event_stack
     self.conditionally_independent = conditionally_independent
     self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
     if reinterpreted_batch_ndims:
         self._distribution = tfd.Independent(
             self._distribution, reinterpreted_batch_ndims=reinterpreted_batch_ndims
         )
     if batch_stack is not None:
         self._distribution = BatchStacker(self._distribution, batch_stack=batch_stack)
     if event_stack is not None:
         self._distribution = tfd.Sample(self._distribution, sample_shape=self.event_stack)
Example #6
def mix(n, eta, loc, scale, name):
    return tfd.Sample(
        tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=eta),
            components_distribution=tfd.Normal(loc=loc, scale=scale),
            name=name),
        sample_shape=n)
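A minimal usage sketch showing the shapes this helper produces, with hypothetical values:

# Hedged usage sketch -- not part of the original example.
d = mix(n=4,
        eta=np.array([0.2, 0.8]),
        loc=np.array([-1.0, 1.0]),
        scale=np.array([0.5, 0.5]),
        name="y")
y = d.sample()         # shape (4,): four i.i.d. draws from the 2-component GMM
print(d.log_prob(y))   # scalar (log probs are summed over the sample_shape)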
Example #7
def create_dp_sb_gmm(nobs, K, dtype=np.float64):
    return tfd.JointDistributionNamed(
        dict(
            # Mixture means
            mu=tfd.Independent(tfd.Normal(np.zeros(K, dtype), 3),
                               reinterpreted_batch_ndims=1),
            # Mixture scales
            sigma=tfd.Independent(tfd.LogNormal(loc=np.full(K, -2, dtype),
                                                scale=0.5),
                                  reinterpreted_batch_ndims=1),
            # Mixture weights (stick-breaking construction)
            alpha=tfd.Gamma(concentration=np.float64(1.0), rate=10.0),
            v=lambda alpha: tfd.Independent(
                # NOTE: Dave Moore suggests doing this instead, to ensure
                # that a batch dimension in alpha doesn't conflict with
                # the other parameters.
                tfd.Beta(np.ones(K - 1, dtype), alpha[..., tf.newaxis]),
                reinterpreted_batch_ndims=1),
            # Observations (likelihood)
            obs=lambda mu, sigma, v: tfd.Sample(
                tfd.MixtureSameFamily(
                    # This will be marginalized over.
                    mixture_distribution=tfd.Categorical(probs=stickbreak(v)),
                    components_distribution=tfd.Normal(mu, sigma)),
                sample_shape=nobs)))
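The stickbreak helper is referenced but not shown here; below is a typical stick-breaking transform consistent with how it is used above (v has K-1 entries in (0, 1); the output has K entries summing to 1), together with a usage sketch. Both are hedged reconstructions, not copied from the original source.

def stickbreak(v):
    # map Beta "sticks" v to mixture weights: w_k = v_k * prod_{j<k} (1 - v_j)
    batch_ndims = len(v.shape) - 1
    cumprod_one_minus_v = tf.math.cumprod(1 - v, axis=-1)
    one_v = tf.pad(v, [[0, 0]] * batch_ndims + [[0, 1]], constant_values=1)
    c_one = tf.pad(cumprod_one_minus_v,
                   [[0, 0]] * batch_ndims + [[1, 0]], constant_values=1)
    return one_v * c_one

dp_gmm = create_dp_sb_gmm(nobs=200, K=10)
sim = dp_gmm.sample()       # dict with mu, sigma, alpha, v, obs
print(sim["obs"].shape)     # (200,)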
Example #8
    def __init__(self, blocks, hdLayers, latent_dim, label_dim, layerSize,
                 input_dim, clipval, permutation, zero_noise, y_noise,
                 latent_perturb_scale, **kwargs):
        '''
        Input:
            blocks      <int> number of blocks that the INN should contain
            hdLayers    <int> number of hidden layers per sub-network
            latent_dim  <int> dimension of the latent space
            label_dim   <int> dimension of the labels
            layerSize   <int> number of neurons in each hidden layer
            input_dim   <int> dimension of the input
            clipval     <float> maximum argument in the exponential function
            permutation <int array> permutation used to shuffle data at the end of a block
            zero_noise  <float> scale of the Gaussian distribution used to fill the padding dimensions
            y_noise     <float> scale of the Gaussian distribution used to perturb the labels
            latent_perturb_scale <float> how much the latent variable is perturbed in the call function
        
        This class contains a whole INN with multiple blocks.
        '''
        super(INN, self).__init__(**kwargs)

        # initializing variables
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.label_dim = label_dim
        self.latent_perturb_scale = latent_perturb_scale

        # initializing the random distributions that are needed
        latent_mean = tf.zeros(2)
        latent_scale = tf.ones(2)
        zero_mean = tf.zeros(self.input_dim - 2)
        label_mean = tf.zeros(self.label_dim)

        self.s = tfd.Sample(tfd.Normal(loc=latent_mean, scale=latent_scale))
        self.s2 = tfd.Sample(tfd.Normal(loc=zero_mean, scale=zero_noise))
        self.s3 = tfd.Sample(tfd.Normal(loc=label_mean, scale=y_noise))

        #filling a list with the INNBlocks for the network
        self.blocks = []
        for curBlock in range(blocks - 1):
            self.blocks.append(
                INNBlock(hdLayers, layerSize, input_dim, clipval, permutation))

        self.blocks.append(
            INNBlockLast(hdLayers, layerSize, input_dim, clipval))
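With the default sample_shape=(), the tfd.Sample wrappers above keep the batch shape of the underlying Normal distributions; a minimal, hypothetical sketch of how such objects are typically used to draw a batch of latent noise (not the INN class's actual call method):

# Hedged sketch -- illustrative values only, not part of the original class.
latent = tfd.Sample(tfd.Normal(loc=tf.zeros(2), scale=tf.ones(2)))
z = latent.sample(5)   # shape (5, 2): one 2-D latent draw per batch element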
Example #9
 def __init__(self,
              name: Optional[NameType],
              *,
              transform=None,
              observed=None,
              plate=None,
              **kwargs):
     self.conditions = self.unpack_conditions(**kwargs)
     self._distribution = self._init_distribution(self.conditions)
     self.plate = plate
     super().__init__(self.unpack_distribution,
                      name=name,
                      keep_return=True,
                      keep_auxiliary=False)
     if name is None and observed is not None:
         raise ValueError(
             "Observed variables are not allowed for anonymous (with name=None) Distributions"
         )
     self.model_info.update(observed=observed)
     self.transform = self._init_transform(transform)
     if self.plate is not None:
         self._distribution = tfd.Sample(self._distribution,
                                         sample_shape=self.plate)
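A minimal sketch of what the plate wrapping at the end does: tfd.Sample turns `plate` conditionally i.i.d. draws of the base distribution into a single event, with log probabilities summed over the plate:

# Hedged sketch -- not part of the original class.
base = tfd.Normal(0.0, 1.0)
plated = tfd.Sample(base, sample_shape=10)
x = plated.sample()          # shape (10,): ten i.i.d. draws as one event
print(plated.log_prob(x))    # scalar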
Example #10
    return tf.linalg.cholesky(K)


def compute_f(alpha, rho, beta, eta):
    LK = compute_LK(alpha, rho, X)
    f = tf.linalg.matvec(LK, eta)  # LK * eta, (matrix * vector)
    return f + beta[..., tf.newaxis]


# GP Binary Classification Model.
gpc_model = tfd.JointDistributionNamed(
    dict(
        alpha=tfd.LogNormal(dtype(0), dtype(1)),
        rho=tfd.LogNormal(dtype(0), dtype(1)),
        beta=tfd.Normal(dtype(0), dtype(1)),
        eta=tfd.Sample(tfd.Normal(dtype(0), dtype(1)),
                       sample_shape=X.shape[0]),
        # NOTE: `Sample` and `Independent` resemble, respectively,
        # `filldist` and `arraydist` in Turing.
        obs=lambda alpha, rho, beta, eta: tfd.Independent(
            tfd.Bernoulli(logits=compute_f(alpha, rho, beta, eta)),
            reinterpreted_batch_ndims=1)))

### MODEL SET UP ###

# For some reason, this is needed for the compiler
# to know the correct model parameter dimensions.
_ = gpc_model.sample()

# Parameters as they appear in model definition.
# NOTE: Initial values should be defined in order appeared in model.
ordered_params = ['alpha', 'rho', 'beta', 'eta']
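
A minimal sketch, assuming X, dtype, and a binary observation vector y_obs from the surrounding script, of turning the model into a target log density plus an initial state listed in the order of ordered_params; y_obs and the starting values are hypothetical:

# Hedged sketch -- not part of the original example.
def unnormalized_log_posterior(alpha, rho, beta, eta):
    return gpc_model.log_prob(dict(alpha=alpha, rho=rho, beta=beta,
                                   eta=eta, obs=y_obs))

initial_state = [
    tf.ones([], dtype=dtype),           # alpha
    tf.ones([], dtype=dtype),           # rho
    tf.zeros([], dtype=dtype),          # beta
    tf.zeros(X.shape[0], dtype=dtype),  # eta
]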
Example #11
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if self.plate is not None:
            self._backend_distribution = tfd.Sample(self._backend_distribution,
                                                    sample_shape=self.plate)
Example #12
def eightGauss(save=False):
    '''
    This function trains an INN on a Gaussian mixture model and then shows how well the network can generate data similar to the training data.
    Option:
        save <bool> if True the figures will be saved
    '''

    numb_of_dists = 8  # number of distributions in the gaussian mixture model
    size = 1000  # number of datapoints per distribution

    zero_noise = 0.08  # scale of the Gaussian distribution used to fill the padding dimensions
    y_noise = 0.105  # scale of the Gaussian distribution used to perturb the label dimensions

    input_dim = 16  # input dimension
    latent_dim = 2  # latent dimension
    label_dim = 4  # label dimension

    # A callback called during training to set the training parameters dynamically.
    class MyCallback(tf.keras.callbacks.Callback):
        def __init__(self, alpha, beta):
            '''
            Input:
                alpha, beta <Keras variable> placeholder for a loss weight
            '''

            self.alpha = alpha
            self.beta = beta
            self.i = 0.0  # counts the number of training steps completed

        def on_train_begin(self, logs={}):
            '''This function is called each time the training starts.'''
            self.i += 1.0

            # set the values of alpha and beta (K.set_value mutates the Keras
            # variables in place and returns None, so no reassignment)
            K.set_value(
                self.alpha,
                20 * np.asarray([1., 2. * 0.003**(1.0 - (self.i / 30.0))]).min())
            K.set_value(self.beta, np.asarray([self.i * 3, 21]).min())

    def genData(s, input_dim, latent_dim, label_dim):
        '''
        Input:
            s           <tensorflow_probability distribution object> contains the distribution used to sample the latent variable.
            input_dim   <int> input dimension
            latent_dim  <int> latent dimension
            label_dim   <int> label dimension
        Output:
            x           <2D numpy array> contains the x and y coordinates of the Gaussian mixture model
            y           <2D numpy array> contains the perturbed labels
            z           <2D numpy array> contains the latent variables
            y_clean     <2D numpy array> contains the unperturbed version of the labels
        
        This function generates all the data needed to train the INN
        '''

        #get x and y coordinates as well as labels for a gaussian mixture model.
        ga = gaussian.mixture(numb_of_dists, size=size)

        #shuffle the data
        np.random.shuffle(ga)

        #store x and y coordinates
        x = np.delete(ga, 2, axis=1)

        #sample a set of latent variables
        z = tf.dtypes.cast(s.sample(numb_of_dists * size),
                           tf.dtypes.float32).numpy()

        #get the labels
        _, _, y_load = ga.T

        #keep the labels in a custom form
        y_clean = np.zeros((numb_of_dists * size, label_dim))
        for j in range(len(y_clean)):
            if y_load[j] < 4:
                y_clean[j][0] += 1.0
            else:
                if y_load[j] < 6:
                    y_clean[j][1] += 1.0
                else:
                    y_clean[j][int(y_load[j] - 4)] += 1.0

        #perturb the labels
        y = y_clean + np.random.normal(
            loc=0.0, scale=y_noise, size=(len(y_load), label_dim))

        return x, y, z, y_clean

    #creating an INN
    inn = INN(3, 3, latent_dim, label_dim, 384, input_dim, 2.0,
              [7, 12, 5, 15, 2, 14, 6, 10, 8, 3, 1, 11, 13, 9, 4, 0],
              zero_noise, y_noise, 0.05)

    #loading an already trained INN
    # inn.load_weights("./network/8Gauss.tf")

    # initializing the callback variables
    alpha = K.variable(0.0)
    beta = K.variable(0.0)

    #compiling the model.
    inn.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001,
                                                   beta_1=0.5,
                                                   beta_2=0.5,
                                                   amsgrad=False,
                                                   clipvalue=15.0),
                loss=[losses.MMD_forward, 'mse', 'mse', losses.MMD_backward],
                loss_weights=[18.0, 0.08, 0.08, beta],
                metrics=['accuracy'])

    #creating distribution objects used to sample the latent variable (s) and padding dimension (s2)
    latent_mean = np.zeros(2)
    zero_mean = np.zeros(input_dim - 2)

    s = tfd.Sample(tfd.Normal(loc=latent_mean, scale=1.0))
    s2 = tfd.Sample(tfd.Normal(loc=zero_mean, scale=zero_noise))

    #training loop
    for i in range(30):

        x_pad = s2.sample((numb_of_dists * size)).numpy()

        x, y, z, y_clean = genData(s, input_dim, latent_dim, label_dim)

        #concatenate the data to their final form used in the training process
        x_t = np.concatenate((x, x_pad, y_clean), axis=1)
        y_t = np.concatenate([
            y,
            np.random.normal(loc=0.0,
                             scale=zero_noise,
                             size=(len(x), input_dim - label_dim - latent_dim))
        ],
                             axis=1)
        z_t = np.concatenate([z, y], axis=1)

        #training the network
        inn.fit(x_t, [z_t, y_t, x_t[:, :input_dim], x],
                epochs=1,
                batch_size=200,
                verbose=2,
                callbacks=[MyCallback(alpha, beta)])  #,callbacks=[tensorboard]

    #saving the network weights
    # inn.save_weights('./network/8Gauss.tf')

    #generating new data to plot the results
    x_pad = s2.sample((numb_of_dists * size)).numpy().astype(np.float32)
    x, y, z, _ = genData(s, input_dim, latent_dim, label_dim)

    x_t = np.concatenate((x, x_pad), axis=1)
    y_t = np.concatenate([
        y,
        np.random.normal(loc=0.0,
                         scale=zero_noise,
                         size=(len(x), input_dim - label_dim - latent_dim))
    ],
                         axis=1)
    z_t = np.concatenate([z, y_t], axis=1)

    show_size = 2000

    #using the generated data to perform an inverse pass through the network
    out = inn.inv(z_t[:show_size])

    #splitting the output into x and y coordinates
    x_coord, y_coord = out[0].numpy().T

    fig = py.figure(figsize=(11, 11))
    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.2, top=0.8)

    ax = fig.add_subplot(111)
    ax.set_xlabel("X", fontsize=25)
    ax.set_ylabel("Y", fontsize=25)
    ax.tick_params(axis='both', which='major', labelsize=24)

    #///////////////////////////////////
    # ax.set_title("Generated Data",fontsize=25)
    ax.set_ylim(bottom=-3.1, top=3.1)
    ax.set_xlim(left=-3.1, right=3.1)
    #//////////////////////////////
    ax.scatter(x_coord,
               y_coord,
               cmap='Spectral',
               c=np.argmax(y[:show_size], axis=1),
               s=2.0)
    if save:
        fig.savefig('8gaussGen.png', dpi=300)

    fig = py.figure(figsize=(11, 11))
    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.2, top=0.8)

    ax = fig.add_subplot(111)
    ax.set_xlabel("X", fontsize=25)
    ax.set_ylabel("Y", fontsize=25)
    ax.tick_params(axis='both', which='major', labelsize=24)
    ax.set_ylim(bottom=-3.1, top=3.1)
    ax.set_xlim(left=-3.1, right=3.1)
    ax.scatter(x[:show_size, 0],
               x[:show_size, 1],
               cmap='Spectral',
               c=np.argmax(y[:show_size], axis=1),
               s=2.0)
    if save:
        fig.savefig('8gaussOrig.png', dpi=300)

    py.show()
Example #13
def polar1(save=False):

    input_dim = 2
    inn = INN(64, input_dim, -1.0, 1.0)
    inn.load_weights("./network/polar1.tf")

    inn.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001,
                                                   beta_1=0.99,
                                                   beta_2=0.999,
                                                   amsgrad=False),
                loss='mse',
                metrics=['accuracy'])

    loc1 = np.asarray([3.0, 0.0])
    loc2 = np.zeros(6)
    s = tfd.Sample(tfd.Uniform(low=[0.1, -5.0], high=[5.0, 5.0]))
    s2 = tfd.Sample(tfd.Normal(loc=loc2, scale=0.05))

    # training loop (disabled: range(0) runs zero iterations because the
    # pretrained weights are loaded above)
    for i in range(0):
        x, y = s.sample(10000).numpy().T
        x_t = np.asarray([x, y]).T
        r, phi = polar(x, y)
        y_t = np.asarray([r, phi]).T
        inn.fit(x_t, y_t, epochs=2, batch_size=256, verbose=2)

    x, y = np.random.normal(4.0, 1.0, (600)), np.random.normal(0.0, 1.0, (600))
    x_t = np.asarray([x, y]).T
    r, phi = polar(x, y)

    x_t = np.asarray([x, y]).T  #network input
    rp, phip = inn.predict(x_t).T
    y_t = np.asarray([r, phi]).T  #inverse network input
    xp, yp = inn.inv(y_t).numpy().T

    # print(np.mean(np.abs(r-rr)))
    # print(np.mean(np.abs(phi-rphi)))

    # for polar1forward------------------------------------------------------------
    fig = py.figure(figsize=(11, 9))
    fig.subplots_adjust(left=0.2, right=0.95, bottom=0.2, top=0.95)

    ax = fig.add_subplot(111)
    ax.plot(rp * np.cos(phip), rp * np.sin(phip), 'ro', label='transformed')
    ax.plot(x, y, 'go', label='original')

    ax.set_xlabel("X", fontsize=30)
    ax.set_ylabel("Y", fontsize=30)
    ax.tick_params(axis='both', which='major', labelsize=29)
    ax.legend(fontsize=25)
    if save:
        fig.savefig('polar1forward.png', dpi=300)

    # for polar1backward ----------------------------------------------------
    fig = py.figure(figsize=(11, 9))
    fig.subplots_adjust(left=0.2, right=0.95, bottom=0.2, top=0.95)

    ax = fig.add_subplot(111)
    ax.plot(xp, yp, 'ro', label='inverse transformed')
    ax.plot(x, y, 'go', label='original')

    ax.set_xlabel("X", fontsize=30)
    ax.set_ylabel("Y", fontsize=30)
    ax.tick_params(axis='both', which='major', labelsize=29)
    ax.legend(fontsize=25)
    if save:
        fig.savefig('polar1backward.png', dpi=300)
    # py.show()

    # for polar1mesh-----------------------------------------
    # rr, rphi = polar(xmesh,ymesh)
    # out = inn.predict(mesh)
    # r,phi = out.T
    # y_t = np.asarray([rr,rphi]).T
    # out = inn.inv(y_t).numpy()
    # xr,yr = out.T

    # fig = py.figure(figsize=(15,15))
    # fig.subplots_adjust(left=0.1,right = 0.9,bottom=0.2,top=0.8)

    # ax = fig.add_subplot(111)
    # ax.plot(r*np.cos(phi),r*np.sin(phi),'bo',label ='transformed')
    # ax.plot(xr,yr,'ro',label='inverse transform')
    # ax.plot(xmesh,ymesh,'go',label='original',alpha=0.5)

    # ax.set_xlabel("X",fontsize=25)
    # ax.set_ylabel("Y",fontsize=25)
    # ax.tick_params(axis='both', which='major', labelsize=24)
    # ax.legend(fontsize = 20)
    # if save:
    # fig.savefig('polar1mesh.png',dpi=300)

    py.show()

    # inn.save_weights('./network/polar1.tf')
    return None
Example #14
def twoGauss(save=False):
    # This function trains an INN on a Gaussian mixture model and then shows how well the network can generate data similar to the training data.
    # Option:
    #   save <bool> if True the figures will be saved
    
    
    numb_of_dists = 2   # number of distributions in the gaussian mixture model
    size = 1000         # number of datapoints per distribution
    
    zero_noise = 0.09   # scale of the Gaussian distribution used to fill the padding dimensions
    y_noise = 0.1       # scale of the Gaussian distribution used to perturb the label dimensions
    
    input_dim = 8       # input dimension
    latent_dim = 2      # latent dimension
    label_dim = 0       # label dimension
    
    # A callback called during training to set the training parameters dynamically.
    class MyCallback(tf.keras.callbacks.Callback):
        def __init__(self, alpha, beta):
            # Input:
            #    alpha, beta <Keras variable> placeholder for a loss weight
            self.alpha = alpha 
            self.beta = beta
            self.i = 0.0 # counts the number of training steps completed
        
        def on_train_begin(self, logs={}):
            #is called each time the training starts.
            self.i += 1.0
            
            # set the values of alpha and beta (K.set_value mutates the Keras
            # variables in place and returns None, so no reassignment)
            K.set_value(
                self.alpha,
                20 * np.asarray([1., 2. * 0.003**(1.0 - (self.i / 30.0))]).min())
            K.set_value(self.beta, np.asarray([self.i * 1.5, 16]).min())
    
    def genData(s,input_dim,latent_dim):
        # Input:
        #   s           <tensorflow_probability distribution object> contains the distribution used to sample the latent variable.
        #   input_dim   <int> input dimension
        #   latent_dim  <int> latent dimension
        # Output:
        #   x           <2D numpy array> contains the x and y coordinates of the Gaussian mixture model
        #   z           <2D numpy array> contains the latent variables
        #
        # This function generates all the data needed to train the INN
        
        #get x and y coordinates as well as labels for a gaussian mixture model.
        ga = gaussian.mixture(numb_of_dists,size=size)
        
        #shuffle the data
        np.random.shuffle(ga)
        
        #store x and y coordinates
        x = np.delete(ga,2,axis=1)
        
        #sample a set of latent variables
        z = tf.dtypes.cast(s.sample(numb_of_dists*size),tf.dtypes.float32).numpy()

        return x,z
    
    #creating an INN
    inn = INN(4,5,latent_dim,label_dim,256,input_dim,2.0,[1,2,7,5,4,6,3,0],zero_noise,y_noise,0.1)
    
    #loading an already trained INN
    # inn.load_weights("./network/2Gauss.tf")

    # initializing the callback variables
    alpha = K.variable(0.0)
    beta = K.variable(0.0)
    
    #compiling the model.
    inn.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001,
                                                   beta_1=0.9,
                                                   beta_2=0.9,
                                                   amsgrad=False,
                                                   clipvalue=15.0),
                loss=[losses.MMD_forward, 'mse', 'mse', losses.MMD_backward],
                loss_weights=[100.0, 4.0, 4.0, 100.0],
                metrics=['accuracy'])
    
    #creating distribution objects used to sample the latent variable (s) and padding dimension (s2)
    latent_mean = np.zeros(2)
    zero_mean = np.zeros(input_dim -2) 

    s = tfd.Sample(tfd.Normal(loc=latent_mean, scale=1.0))
    s2 = tfd.Sample(tfd.Normal(loc=zero_mean, scale=zero_noise))
    
    
    #training loop
    for i in range(60):
       
        #generate data
        x_pad = s2.sample((numb_of_dists*size)).numpy()
        x,z = genData(s,input_dim,latent_dim)
        
        #concatenate the data to their final form used in the training process
        x_t = np.concatenate((x,x_pad),axis=1)         
        y_t = np.concatenate([np.random.normal(loc=0.0,scale=zero_noise,size=(len(x),input_dim - label_dim - latent_dim))],axis=1)
        
        #training the network
        inn.fit(x_t,[z,y_t,x_t[:,:input_dim],x],epochs=2,batch_size=2000,verbose=2,callbacks=[MyCallback(alpha,beta)]) #,callbacks=[tensorboard]
    
    #saving the network weights
    # inn.save_weights('./network/2Gauss.tf')
    
    
    #generating new data to plot the results
    x_pad = s2.sample((numb_of_dists*size)).numpy().astype(np.float32)
    x,z = genData(s,input_dim,latent_dim)
    #z= np.random.uniform(low=-1.0,high=1.0,size=(2000,2))
    
    x_t = np.concatenate((x,x_pad),axis=1)
    y_t = np.concatenate([np.random.normal(loc=0.0,scale=zero_noise,size=(len(x),input_dim - label_dim - latent_dim))],axis=1)
    z_t = np.concatenate([z,y_t],axis=1)
    
    
    show_size = 2000 
    
    #using the generated data to perform an inverse pass through the network
    out = inn.inv(z_t[:show_size])  

    #splitting the output into x and y coordinates
    x_coord,y_coord = out[0].numpy().T
   
    fig=py.figure(figsize=(14,3.5))
    fig.subplots_adjust(left=0.1,right = 0.9,bottom=0.25,top=0.9)
    
    ax = fig.add_subplot(111)
    ax.set_xlabel("X",fontsize=25)
    ax.set_ylabel("Y",fontsize=25)
    ax.tick_params(axis='both', which='major', labelsize=24)
    
    #///////////////////////////////////
    # ax.set_title("Generated Data",fontsize=25)
    ax.set_ylim(bottom = -0.5,top =0.5)
    ax.set_xlim(left = -3.1,right=3.1)
    #//////////////////////////////
    ax.scatter(x_coord,y_coord,c='red',s=2.0)
    if save:
        fig.savefig('2gaussGen.png',dpi=300)
    
    fig=py.figure(figsize=(14,3.5))
    fig.subplots_adjust(left=0.1,right = 0.9,bottom=0.25,top=0.9)
    
    ax = fig.add_subplot(111)
    ax.set_xlabel("X",fontsize=25)
    ax.set_ylabel("Y",fontsize=25)
    ax.tick_params(axis='both', which='major', labelsize=24)
    ax.set_ylim(bottom = -0.5,top =0.5)
    ax.set_xlim(left = -3.1,right=3.1)
    ax.scatter(x[:show_size,0],x[:show_size,1],c='red',s=2.0)
    if save:
        fig.savefig('2gaussOrig.png',dpi=300)

    py.show()