Example #1
 def test_sampling(self):
     init = initilaze_topic_model()
     init.initilize()
     sampleman = Sampling(init.xcorpus, init.ycorpus)
     sampleman.sampling(init.TOPICS, init.xcounts, init.ycounts, init.docid, init.different_word)
     print(sampleman.xcorpus)
     print(sampleman.ycorpus)
Example #2
 def fisher_sample(self):
     if self.sampler is None:
         print("preprocessing")
         if self.dico_fisher is None:
             print('you need to compute the fisher information first')
             return
         self.sampler = Sampling(self.build_mean(), self.dico_fisher)
         print('sampling ok')
     config = self.network.get_config()
     if self.network.__class__.__name__=='Sequential':
         new_model = Sequential.from_config(config)
     else:
         new_model = Model.from_config(config)
     new_params = self.sampler.sample()
     """
     means = self.sampler.mean
     
     for key in means:
         if np.max(np.abs(means[key] - new_params[key]))==0:
             print key
     print('kikou')
     import pdb; pdb.set_trace()
     """
     #tmp_prob = self.sampler.prob(new_params)
     new_model.compile(loss=self.network.loss,
                       optimizer=str.lower(self.network.optimizer.__class__.__name__),
                       metrics = self.network.metrics)
     new_model.set_weights(self.network.get_weights())
     self.copy_weights(new_model, new_params)
     return new_model
Example #3
 def __init__(self, ):
     self.sampling_mpu_0x68 = Sampling(AK8963_ADDRESS, MPU9050_ADDRESS_68,
                                       MPU9050_ADDRESS_68, 1, GFS_1000,
                                       AFS_8G, AK8963_BIT_16,
                                       AK8963_MODE_C100HZ)
     self.sampling_mpu_0x69 = Sampling(AK8963_ADDRESS, MPU9050_ADDRESS_69,
                                       None, 1, GFS_1000, AFS_8G,
                                       AK8963_BIT_16, AK8963_MODE_C100HZ)
Example #4
 def test_sampling(self):
     init = initilaze_topic_model()
     init.initilize()
     sampleman = Sampling(init.xcorpus, init.ycorpus)
     sampleman.sampling(init.TOPICS, init.xcounts, init.ycounts, init.docid,
                        init.different_word)
     print(sampleman.xcorpus)
     print(sampleman.ycorpus)
Example #5
 def test_uniformly_sampling(self, test_input, expected):
     batch_size = 4
     order_provider = Sampling.get_uniformly_sampled_order(
         test_input, 4, batch_size)
     order = next(order_provider)
     result = chain.from_iterable(
         [[test_input[ind] for ind in order[i:i + batch_size]]
          for i in range(0, len(test_input), batch_size)])
     assert all(a == b for a, b in zip(result, expected))
Example #6
 def test_order_iterator(self, test_input):
     order_provider = OrderProvider(
         Sampling.get_random_order(len(test_input)))
     order1 = next(order_provider)
     order2 = next(order_provider)
     assert all(a == b for a, b in zip(order1, order2))
     order_provider.update()
     order3 = next(order_provider)
     assert any(a != b for a, b in zip(order1, order3))
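
The two tests above imply a simple contract for the order provider: `next()` keeps returning the same order until `update()` is called, which reshuffles it. A hedged sketch of the epoch loop this suggests (the dataset literal and the training comment are placeholders, not from the original tests):

# Hypothetical epoch loop built on the order-provider pattern exercised above.
dataset = ['a', 'b', 'c', 'd']
order_provider = OrderProvider(Sampling.get_random_order(len(dataset)))
for epoch in range(3):
    order = next(order_provider)
    shuffled = [dataset[i] for i in order]   # iterate the data in this epoch's order
    # ... train on `shuffled` ...
    order_provider.update()                  # draw a new order for the next epoch
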
Example #7
    def __create_sequences(self, X, labels, sampling=True):
        sequences = self.tokenizer.texts_to_sequences(X)

        data = pad_sequences(sequences, padding='post',
                             maxlen=self.params['max_length'])

        indices = np.arange(data.shape[0])
        np.random.shuffle(indices)
        data = data[indices]
        labels = labels[indices]

        if sampling:
            sample = Sampling(2., .5)
            x_train, y_train = sample.perform_sampling(data, labels, [0, 1])
        else:
            x_train, y_train = data, labels

        return x_train, y_train
Example #8
    def __init__(self, data_infile, fasttext_model_path, triplet_margin=0.1):

        self.sampling = Sampling(data_infile, fasttext_model_path)

        self.amount_negative_names = 1
        self.triplet_margin = triplet_margin
        self.anchor_margin = 0

        self.loss_weights = {'synonym': 1, 'proto': 1}

        torch.autograd.set_detect_anomaly(True)
Example #9
    def __init__(self, data_infile, fasttext_model_path, triplet_margin=0.1):

        self.sampling = Sampling(data_infile, fasttext_model_path)

        self.amount_negative_names = 1
        self.triplet_margin = triplet_margin
        self.anchor_margin = 0

        self.loss_weights = {
            'semantic_similarity': 1,
            'contextual': 1,
            'grounding': 1
        }

        torch.autograd.set_detect_anomaly(True)
Example #10
 def __init__(self, latent_dim, seed):
     super(encoder, self).__init__()
     np.random.seed(seed)
     self.layer_1 = Conv2D(
         filters=32,
         kernel_size=(4, 4),
         activation="relu",
         strides=2,
         padding="same",
         kernel_initializer=tf.keras.initializers.HeNormal(seed))
     self.layer_2 = Conv2D(
         filters=32,
         kernel_size=(4, 4),
         activation="relu",
         strides=2,
         padding="same",
         kernel_initializer=tf.keras.initializers.HeNormal(seed))
     self.layer_3 = Conv2D(
         filters=64,
         kernel_size=(4, 4),
         activation="relu",
         strides=2,
         padding="same",
         kernel_initializer=tf.keras.initializers.HeNormal(seed))
     self.layer_4 = Conv2D(
         filters=64,
         kernel_size=(4, 4),
         activation="relu",
         strides=2,
         padding="same",
         kernel_initializer=tf.keras.initializers.HeNormal(seed))
     self.layer_5 = Dense(
         units=128,
         activation='relu',
         kernel_initializer=tf.keras.initializers.HeNormal(seed))
     self.dense_log_var = Dense(
         units=latent_dim,
         kernel_initializer=tf.keras.initializers.HeNormal(seed))
     self.dense_mean = Dense(
         units=latent_dim,
         kernel_initializer=tf.keras.initializers.HeNormal(seed))
     self.sampling = Sampling()
     self.batch_norm_1 = BatchNormalization()
     self.batch_norm_2 = BatchNormalization()
     self.batch_norm_3 = BatchNormalization()
     self.batch_norm_4 = BatchNormalization()
     self.flatten = Flatten()
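
`Sampling()` is not defined in this snippet; given the `dense_mean`/`dense_log_var` heads above, it is most likely the reparameterization layer from the standard Keras VAE example. A minimal sketch under that assumption:

# Assumed definition of the Sampling layer used by the encoder above
# (the usual VAE reparameterization trick: z = mean + exp(0.5 * log_var) * eps).
import tensorflow as tf

class Sampling(tf.keras.layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
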
Example #11
tokenizer = Tokenizer(num_words=MAX_NB_WORDS)
tokenizer.fit_on_texts(X)
sequences = tokenizer.texts_to_sequences(X)
word_index = tokenizer.word_index

data = pad_sequences(sequences, padding='post', maxlen=MAX_SEQUENCE_LENGTH)
print('Shape of data tensor:', data.shape)
print('Shape of label tensor:', labels.shape)

indices = np.arange(data.shape[0])
np.random.shuffle(indices)
data = data[indices]
labels = labels[indices]

num_validation_samples = int(VALIDATION_SPLIT * data.shape[0])
sample = Sampling(2., .5)
x_train, y_train = sample.perform_sampling(data[:-num_validation_samples],
                                           labels[:-num_validation_samples],
                                           [0, 1])
x_val = data[-num_validation_samples:]
y_val = labels[-num_validation_samples:]
print('Number of entries in each category:')
print('training: ', y_train.sum(axis=0))
print('validation: ', y_val.sum(axis=0))

model = Word2Vec.load('1ft.modelFile')

embeddings_index = {}
embedding_matrix = np.random.random((len(word_index) + 1, EMBEDDING_DIM))
for word, i in word_index.items():
    embedding_vector = model.wv[word]
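
The final loop above is cut off in the source: it looks up `model.wv[word]` but never writes into `embedding_matrix`. The conventional completion would look like the sketch below (the out-of-vocabulary guard is an added assumption, since `model.wv[word]` raises KeyError for unseen words):

# Hypothetical completion of the embedding-matrix loop above.
for word, i in word_index.items():
    if word in model.wv:
        embedding_matrix[i] = model.wv[word]
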
Example #12
#!/usr/bin/env python

import numpy as np
import fitsio
import scipy.optimize as optimize
from sampling import Sampling

sampling = Sampling(nsamples=1000)
sampling.set_flux(total_flux=1000., noise=0.001)


def mem_function(u, A, f, llambda):
    Ar = (A.dot(u) - f)
    As = (Ar**2).sum()
    Bs = (u * np.log(u)).sum()
    val = As + llambda * Bs
    grad = 2. * A.T.dot(Ar) + llambda * (1. + np.log(u))
    return (val, grad)


def mem_fit(sampling, llambda=1.e-2):
    S_M0 = np.ones(sampling.nx * sampling.ny)
    bounds = zip([1.e-5] * len(S_M0), [None] * len(S_M0))
    bounds = [x for x in bounds]
    flux = sampling.flux
    results = optimize.minimize(mem_function,
                                S_M0,
                                args=(sampling.A, flux, llambda),
                                method='L-BFGS-B',
                                jac=True,
                                bounds=bounds)
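
`mem_function` implements the maximum-entropy objective ||A u - f||^2 + lambda * sum(u * log u) and returns both the value and its analytic gradient 2 A^T (A u - f) + lambda * (1 + log u), which is why `minimize` is called with `jac=True`. A small self-contained sanity check (not part of the original script; `A_test`, `f_test`, and `u0` are made-up inputs) comparing the analytic gradient with scipy's finite-difference estimate:

# Gradient sanity check for mem_function (hypothetical, for illustration only).
rng = np.random.default_rng(0)
A_test = rng.random((5, 3))
f_test = rng.random(5)
u0 = np.full(3, 0.5)
err = optimize.check_grad(
    lambda u: mem_function(u, A_test, f_test, 1e-2)[0],
    lambda u: mem_function(u, A_test, f_test, 1e-2)[1],
    u0)
print('gradient check error:', err)
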
Example #13
 def test_random_sampling(self, test_input, expected):
     order_provider = Sampling.get_random_order(test_input)
     order = next(order_provider)
     result = set(order)
     assert result == expected
Example #14
def comp_sketch(matrix, objective, load_N=False, save_N=False, N_dir='../N_file/', **kwargs):
    """
    Given matrix A, the function comp_sketch computes a sketch for A and performs further operations on PA.
    It returns the total running time and the desired quantity.

    parameter:
        matrix: a RowMatrix object storing the matrix [A b]
        objective: either 'x' or 'N'
            'x': the function returns the solution to the problem min_x || PA[:,:-1]x - PA[:,-1] ||_2
            'N': the function returns a square matrix N such that PA[:,:-1]*inv(N) is a matrix with orthonormal columns
        load_N: load the precomputed N matrices if possible (it reduces the actual running time for sampling sketches)
        save_N: save the computed N matrices for future use
        sketch_type: either 'projection' or 'sampling'
        projection_type: cw, gaussian, rademacher or srdht
        c: projection size
        s: sampling size (for sampling sketch only)
        k: number of independent trials to run
    """

    sketch_type = kwargs.get('sketch_type')

    if not os.path.exists(N_dir):
        os.makedirs(N_dir)

    if objective == 'x':
        
        if sketch_type == 'projection':
            projection = Projections(**kwargs)
            t = time.time()
            x = projection.execute(matrix, 'x', save_N)
            t = time.time() - t

            if save_N:
                logger.info('Saving N matrices from projections!')
                N = [a[0] for a in x]
                x = [a[1] for a in x]
                # saving N
                filename = N_dir + 'N_' + matrix.name + '_projection_' + kwargs.get('projection_type') + '_c' + str(int(kwargs.get('c'))) + '_k' + str(int(kwargs.get('k')))+ '.dat'
                data = {'N': N, 'time': t}
                pickle_write(filename,data)
 
        elif sketch_type == 'sampling':
            s = kwargs.get('s')
            new_N_proj = 0
            N_proj_filename = N_dir + 'N_' + matrix.name + '_projection_' + kwargs.get('projection_type') + '_c' + str(int(kwargs.get('c'))) + '_k' + str(int(kwargs.get('k'))) +'.dat'

            if load_N and os.path.isfile(N_proj_filename):
                logger.info('Found N matrices from projections, loading them!')
                N_proj_filename = N_dir + 'N_' + matrix.name + '_projection_' + kwargs.get('projection_type') + '_c' + str(int(kwargs.get('c'))) + '_k' + str(int(kwargs.get('k'))) +'.dat'
                result = pickle_load(N_proj_filename)
                N_proj = result['N']
                t_proj = result['time']
            else: # otherwise, compute it
                t = time.time()
                projection = Projections(**kwargs)
                N_proj = projection.execute(matrix, 'N')
                t_proj = time.time() - t
                new_N_proj = 1

            sampling = Sampling(N=N_proj)
            t = time.time()
            x = sampling.execute(matrix, 'x', s, save_N )
            t = time.time() - t + t_proj

            if save_N and new_N_proj:
                logger.info('Saving N matrices from projections!')
                #filename = N_dir + 'N_' + matrix.name + '_projection_' + kwargs.get('projection_type') + '_c' + str(int(kwargs.get('c'))) + '_k' + str(int(kwargs.get('k'))) + '.dat'
                data = {'N': N_proj, 'time': t_proj}
                pickle_write(N_proj_filename,data)

            if save_N:
                logger.info('Saving N matrices from sampling!')
                N = [a[0] for a in x]
                x = [a[1] for a in x]
                filename = N_dir + 'N_' + matrix.name + '_sampling_s' + str(int(kwargs.get('s'))) + '_' + kwargs.get('projection_type') + '_c' + str(int(kwargs.get('c'))) + '_k' + str(int(kwargs.get('k'))) + '.dat'
                data = {'N': N, 'time': t}
                pickle_write(filename,data)

        else:
            raise ValueError('Please enter a valid sketch type!')
        return x, t

    elif objective == 'N':
        if sketch_type == 'projection':
            N_proj_filename = N_dir + 'N_' + matrix.name + '_projection_' + kwargs.get('projection_type') + '_c' + str(int(kwargs.get('c'))) + '_k' + str(int(kwargs.get('k'))) + '.dat'

            if load_N and os.path.isfile(N_proj_filename):
                logger.info('Found N matrices from projections, loading them!')
                result = pickle_load(N_proj_filename)
                N = result['N']
                t = result['time']
            else:
                t = time.time()
                projection = Projections(**kwargs)
                N = projection.execute(matrix, 'N')
                t = time.time() - t

                if save_N:
                    logger.info('Saving N matrices from projections!')
                    data = {'N': N, 'time': t}
                    pickle_write(N_proj_filename,data)

        elif sketch_type == 'sampling':
            s = kwargs.get('s')
            new_N_proj = 0
            new_N_samp = 0

            N_samp_filename = N_dir + 'N_' + matrix.name + '_sampling_s' + str(int(kwargs.get('s'))) + '_' + kwargs.get('projection_type') + '_c' + str(int(kwargs.get('c'))) + '_k' + str(int(kwargs.get('k'))) + '.dat'
            N_proj_filename = N_dir + 'N_' + matrix.name + '_projection_' + kwargs.get('projection_type') + '_c' + str(int(kwargs.get('c'))) + '_k' + str(int(kwargs.get('k'))) + '.dat'

            if load_N and os.path.isfile(N_samp_filename):
                logger.info('Found N matrices from sampling, loading them!')
                result = pickle_load(N_samp_filename)
                N = result['N']
                t = result['time']

            elif load_N and os.path.isfile(N_proj_filename):
                logger.info('Found N matrices from projections, loading them!')
                result = pickle_load(N_proj_filename)
                N_proj = result['N']
                t_proj = result['time']

                sampling = Sampling(N=N_proj)
                t = time.time()
                N = sampling.execute(matrix, 'N', s)
                t = time.time() - t + t_proj
                new_N_samp = 1

            else:
                t = time.time()
                projection = Projections(**kwargs)
                N_proj = projection.execute(matrix, 'N')
                t_proj = time.time() - t
                new_N_proj = 1

                t = time.time()
                sampling = Sampling(N=N_proj)
                N = sampling.execute(matrix, 'N', s)
                t = time.time() - t + t_proj
                new_N_samp = 1

            if save_N and new_N_proj:
                logger.info('Saving N matrices from projections!')
                data = {'N': N_proj, 'time': t_proj}
                pickle_write(N_proj_filename,data)

            if save_N and new_N_samp:
                logger.info('Saving N matrices from sampling!')
                data = {'N': N, 'time': t}
                pickle_write(N_samp_filename,data)

        else:
            raise ValueError('Please enter a valid sketch type!')
        return N, t
    else:
        raise ValueError('Please enter a valid objective!')
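
A hedged call sketch for `comp_sketch`, using only the keyword arguments described in its docstring; `row_matrix` stands in for a RowMatrix object storing [A b], and the concrete sizes are placeholders:

# Hypothetical invocation of comp_sketch with a sampling sketch.
x, t = comp_sketch(row_matrix, 'x',
                   load_N=True, save_N=True,
                   sketch_type='sampling',
                   projection_type='gaussian',
                   c=500, s=200, k=3)
print('least-squares solutions computed in %.2f seconds' % t)
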
Example #15
path1 = "/Users/liumeiyu/Downloads/IMG_7575.JPG"
path2 = "/Users/liumeiyu/Downloads/test1.jpg"
path3 = "/Users/liumeiyu/Downloads/test2.jpg"

A = Histogram(path1)
B = Smooth(path1)
C = Change(path1)
D = Base(path1)
E = Binary(path1)
F = D_E(path1)
G = Warp(path1)
H = Cvt(path1)
K = Edge_detection(path1)
L = Segmentation(path1)
M = Mosaic(path3)
N = Sampling(path1)
P = Fusion(path1)

# A.img_histogram_trans()
# A.img_histogram()

# B.linear_smooth_np()
# B.linear_smooth()
# B.box_smooth()
# B.gaussian_smooth()
# B.median_smooth()
# B.median_smooth_x(5)

# C.fft_high_change(60)
# C.change_cv()
Example #16
def s_res_width(img, width, smooth_type):
    ratio_scale = (width * 1.0) / img.shape[1]
    height = int(ratio_scale * img.shape[0])
    nimg = s.smooth_resize(img, height, width, smooth_type)
    return nimg
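
A hedged usage sketch for `s_res_width` (the image path and the `smooth_type` value are assumptions; `s` is whatever smoothing module the original file imports):

# Hypothetical call: resize an image to 320 px wide, preserving aspect ratio.
import cv2
img = cv2.imread('input.jpg')
resized = s_res_width(img, 320, 'gaussian')
print(resized.shape)
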
Example #17
 def test_sentence_size_sorted_sampling(self, test_input, expected):
     order_provider = Sampling.get_sentence_size_sorted_order(test_input, 2)
     order = next(order_provider)
     result = set(order)
     assert result == expected
Example #18
class Sensors:

    sampling_mpu_0x68 = None
    sampling_mpu_0x69 = None

    running = False

    def __init__(self, ):
        self.sampling_mpu_0x68 = Sampling(AK8963_ADDRESS, MPU9050_ADDRESS_68,
                                          MPU9050_ADDRESS_68, 1, GFS_1000,
                                          AFS_8G, AK8963_BIT_16,
                                          AK8963_MODE_C100HZ)
        self.sampling_mpu_0x69 = Sampling(AK8963_ADDRESS, MPU9050_ADDRESS_69,
                                          None, 1, GFS_1000, AFS_8G,
                                          AK8963_BIT_16, AK8963_MODE_C100HZ)
        # self.calibrate()

    def configure(self):
        self.sampling_mpu_0x68.configure()
        self.sampling_mpu_0x69.configure()

    def reset(self):
        self.sampling_mpu_0x68.reset()
        self.sampling_mpu_0x69.reset()
        # self.configure()

    def calibrate(self):
        self.sampling_mpu_0x68.calibrate()
        self.sampling_mpu_0x69.calibrate()
        # self.configure()

    def start(self):
        timeSync = time.time()
        self.sampling_mpu_0x68.startSampling(timeSync)
        self.sampling_mpu_0x69.startSampling(timeSync)
        self.running = True

    def stop(self):

        if self.running:
            self.sampling_mpu_0x68.stopSampling()
            self.sampling_mpu_0x69.stopSampling()
            self.running = False

    def showCurrent(self):
        def formatValue(array, index):

            format = "{: 4.17f}"

            if len(array) > index:
                return format.format(array[index])
            else:
                "        null        "

        def formatLabel(value):
            return value.center(20)

        data_0x68 = self.sampling_mpu_0x68.getAllData()
        data_0x69 = self.sampling_mpu_0x69.getAllData()

        print(
            "----------------------------------------------------------------------------------------------------------",
            "\n",
            "Time: ",
            time.time(),
            "\n",
            "----------------------------------------------------------------------------------------------------------",
            "\n",
            "MPU = ",
            formatLabel("0x68_master"),
            " | ",
            formatLabel("0x68_slave_of_0x68"),
            " | ",
            formatLabel("0x69_master"),
            " | ",
            formatLabel("0x68_slave_of_0x69"),
            "\n",
            "----------------------------------------------------------------------------------------------------------",
            "\n",
            "A_X = ",
            formatValue(data_0x68, 1),
            " | ",
            formatValue(data_0x68, 7),
            " | ",
            formatValue(data_0x69, 1),
            " | ",
            formatValue(data_0x69, 7),
            "\n",
            "A_Y = ",
            formatValue(data_0x68, 2),
            " | ",
            formatValue(data_0x68, 8),
            " | ",
            formatValue(data_0x69, 2),
            " | ",
            formatValue(data_0x69, 8),
            "\n",
            "A_Z = ",
            formatValue(data_0x68, 3),
            " | ",
            formatValue(data_0x68, 9),
            " | ",
            formatValue(data_0x69, 3),
            " | ",
            formatValue(data_0x69, 9),
            "\n",
            "----------------------------------------------------------------------------------------------------------",
            "\n",
            "G_X = ",
            formatValue(data_0x68, 4),
            " | ",
            formatValue(data_0x68, 10),
            " | ",
            formatValue(data_0x69, 4),
            " | ",
            formatValue(data_0x69, 10),
            "\n",
            "G_Y = ",
            formatValue(data_0x68, 5),
            " | ",
            formatValue(data_0x68, 11),
            " | ",
            formatValue(data_0x69, 5),
            " | ",
            formatValue(data_0x69, 11),
            "\n",
            "G_Z = ",
            formatValue(data_0x68, 6),
            " | ",
            formatValue(data_0x68, 12),
            " | ",
            formatValue(data_0x69, 6),
            " | ",
            formatValue(data_0x69, 12),
            "\n",
            "----------------------------------------------------------------------------------------------------------",
            "\n",
            "M_X = ",
            formatValue(data_0x68, 13),
            " | ",
            "        null        ",
            " | ",
            formatValue(data_0x69, 13),
            " | ",
            "        null        ",
            "\n",
            "M_Y = ",
            formatValue(data_0x68, 14),
            " | ",
            "        null        ",
            " | ",
            formatValue(data_0x69, 14),
            " | ",
            "        null        ",
            "\n",
            "M_Z = ",
            formatValue(data_0x68, 15),
            " | ",
            "        null        ",
            " | ",
            formatValue(data_0x69, 15),
            " | ",
            "        null        ",
            "\n",
            "----------------------------------------------------------------------------------------------------------",
            "\n",
        )
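
A minimal driver sketch for the Sensors class above (the 1-second display loop is an arbitrary choice, not from the original; `time` is already used by the class, so it is assumed to be imported in the original module):

# Hypothetical driver for the Sensors class defined above.
if __name__ == '__main__':
    sensors = Sensors()
    sensors.configure()
    sensors.start()
    try:
        while True:
            sensors.showCurrent()
            time.sleep(1)
    except KeyboardInterrupt:
        sensors.stop()
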
Example #19
        with open(path, "r", encoding=encoding) as f:
            for line in f:
                words = tokenize(line.strip())
                if len(words) < window_size + 1:
                    continue
                for i in range(len(words)):
                    example = (
                        words[max(0, i - window_size):i] +
                        words[min(i + 1, len(words)
                                  ):min(len(words), i + window_size) + 1],
                        words[i])
                    examples.append(Example.fromlist(example, fields))
        super(CBOWDataset, self).__init__(examples, fields, **kwargs)


if __name__ == '__main__':
    test_path = '/home/lightsmile/NLP/corpus/novel/test.txt'
    dataset = CBOWDataset(test_path, Fields)
    print(len(dataset))
    print(dataset[0])
    print(dataset[0].context)
    print(dataset[0].target)

    TARGET.build_vocab(dataset)

    from sampling import Sampling

    samp = Sampling(TARGET.vocab)

    print(samp.sampling(3))
Example #20
class Fisher(object):

    def __init__(self, network):
        """
        proper documentation
        """
        # remove dropout from the network if needed
        self.network = network
        self.f = None
        self.stochastic_layers = {}
        self.filter_layer()
        
        layers = self.network.layers
        for layer in layers:
            if layer.name in self.stochastic_layers:
                tmp = self.stochastic_layers[layer.name]
                setattr(layer, tmp[0], tmp[2])
        
        layers = self.network.layers
        intermediate_layers_input = []
        intermediate_layers_output = []
        for layer in layers:
            if re.match('merge_(.*)', layer.name):
                intermediate_layers_input.append(layer.input[0])
                intermediate_layers_output.append(layer.output)
            else:
                intermediate_layers_input.append(layer.input)
                intermediate_layers_output.append(layer.output)

        self.intermediate_input = Model(self.network.input, [input_ for input_ in intermediate_layers_input])
        self.intermediate_output = Model(self.network.input, [output_ for output_ in intermediate_layers_output])
        self.f = None
        layers = self.network.layers
        for layer in layers:
            if layer.name in self.stochastic_layers:
                tmp = self.stochastic_layers[layer.name]
                setattr(layer, tmp[0], tmp[1])
                
        self.dico_fisher = None
        self.sampler = None


    def filter_layer(self):
        layers = self.network.layers
        for layer in layers:
            if re.match('dropout_(.*)', layer.name):
                self.stochastic_layers[layer.name]=['p', layer.p, 0]
            if re.match('batchnormalization_(.*)', layer.name):
                self.stochastic_layers[layer.name]=['mode',layer.mode, 1]


    def fisher_information(self, X, Y, recompute=False):
        if not(recompute) and not(self.dico_fisher is None):
            return self.dico_fisher
        assert X.ndim == 4, "X must be a 4d tensor of shape (num_elem, num_channels, width, height) but has %d dimensions" % X.ndim
        assert Y.ndim <= 2, "Y contains the class label and should be of size (num_elem, 1) but has %d dimensions" % Y.ndim

        if Y.ndim==1:
            #Y = Y.dimshuffle((0, 'x'))
            Y = Y[:,None]

        if Y.ndim==2 and Y.shape[1]>1:
            Y = np.argmax(Y, axis=1)[:,None]
        
        layers = self.network.layers
        for layer in layers:
            if layer.name in self.stochastic_layers:
                tmp = self.stochastic_layers[layer.name]
                setattr(layer, tmp[0], tmp[2])

        dico_fisher, f = build_fisher(X, Y, self.network, self.intermediate_input, self.intermediate_output, f =self.f) # remove break
        self.f = f
        dico_conv_biases = build_fisher_biases(X, Y, self.network)
        for key in dico_conv_biases:
            dico_fisher[key] = dico_conv_biases[key]

        for layer in layers:
            if layer.name in self.stochastic_layers:
                tmp = self.stochastic_layers[layer.name]
                setattr(layer, tmp[0], tmp[1])
                
        # TO DO INVERSE LAYERS !!!!!!!
        self.dico_fisher = dico_fisher
        return self.dico_fisher
    
    def fisher_queries(self, X, Y):

        assert X.ndim == 4, "X must be a 4d tensor of shape (num_elem, num_channels, width, height) but has %d dimensions" % X.ndim
        assert Y.ndim <= 2, "Y contains the class label and should be of size (num_elem, 1) but has %d dimensions" % Y.ndim

        if Y.ndim==1:
            #Y = Y.dimshuffle((0, 'x'))
            Y = Y[:,None]

        if Y.ndim==2 and Y.shape[1]>1:
            Y = np.argmax(Y, axis=1)[:,None]
        
        layers = self.network.layers
        for layer in layers:
            if layer.name in self.stochastic_layers:
                tmp = self.stochastic_layers[layer.name]
                setattr(layer, tmp[0], tmp[2])

        dico_fisher = build_queries(X, Y, self.network, self.intermediate_input, self.intermediate_output) # remove break
        """
        dico_conv_biases = build_fisher_biases(X, Y, self.network)
        for key in dico_conv_biases:
            dico_fisher[key] = dico_conv_biases[key]

        for layer in layers:
            if layer.name in self.stochastic_layers:
                tmp = self.stochastic_layers[layer.name]
                setattr(layer, tmp[0], tmp[1])
                
        # TO DO INVERSE LAYERS !!!!!!!
        """
        return dico_fisher
    
    
    def save(self, repo, filename):
        with closing(open(os.path.join(repo, filename), 'wb')) as f:
            pkl.dump(self.dico_fisher, f, protocol=pkl.HIGHEST_PROTOCOL)
            
    def load(self, repo, filename):
        with closing(open(os.path.join(repo, filename), 'rb')) as f:
            self.dico_fisher = pkl.load(f)
        # temporary
        #self.dico_fisher=  dict([('conv_1_bias', self.dico_fisher['conv_1_bias']), ('conv_2_bias', self.dico_fisher['conv_2_bias'])])
        

    def build_mean(self):
        if self.dico_fisher is None:
            print('you need to compute the fisher information first')
            return {}

        layers = self.network.layers
        dico_mean = {}
        for layer in layers:
            if re.match('convolution(.*)', layer.name):
                nb_layer = layer.name.split('_')[1]
                # note: the bias is stored separately from the weights
                W, b = layer.get_weights()
                dico_mean['conv_'+nb_layer] = W.flatten()
                dico_mean['conv_'+nb_layer+'_bias'] = b.flatten()
            elif re.match('dense_(.*)', layer.name):
                W, b = layer.get_weights()
                dico_mean[layer.name]=np.concatenate([W.flatten(), b.flatten()], axis=0)

            elif re.match('batchnormalization_(.*)', layer.name):
                gamma = layer.gamma.get_value()
                beta = layer.beta.get_value()
                dico_mean[layer.name]=np.concatenate([gamma.flatten(), beta.flatten()], axis=0)
        return dico_mean
    
    def copy_weights(self, model, dico_weights):
        layers = model.layers
        for layer in layers:
            if re.match('convolution(.*)input(.*)', layer.name):
                continue
            if re.match('convolution(.*)', layer.name):
                nb_layer = layer.name.split('_')[1]
                
                name_W = 'conv_'+nb_layer
                if name_W in dico_weights.keys():
                    W = layer.W;
                    W_value = dico_weights['conv_'+nb_layer].astype('float32')
                    W_shape = W.shape.eval()
                    W.set_value(W_value.reshape(W_shape))
                
                name_b = 'conv_'+nb_layer+'_bias'
                if name_b in dico_weights.keys():
                    b = layer.b
                    b_value = dico_weights['conv_'+nb_layer+'_bias'].astype('float32')
                    b_shape = b.shape.eval()
                    b.set_value(b_value.reshape(b_shape))
                
            elif re.match('dense_(.*)', layer.name):
                W = layer.W; b = layer.b
                W_shape = layer.W.shape.eval()
                b_shape = layer.b.shape.eval()
                if layer.name in dico_weights.keys():
                    params_W_b = dico_weights[layer.name].astype('float32')
                    split = len(params_W_b) - np.prod(b_shape)
                    W_value = params_W_b[:split]
                    b_value = params_W_b[split:]
                    W.set_value(W_value.reshape(W_shape))
                    b.set_value(b_value.reshape(b_shape))
                
            elif re.match('batchnormalization_(.*)', layer.name):
                gamma = layer.gamma
                beta = layer.beta
                gamma_shape = gamma.shape.eval()
                beta_shape = beta.shape.eval()
                gamma_split = np.prod(gamma_shape)
                if layer.name in dico_weights.keys():
                    params_gamma_beta = dico_weights[layer.name].astype('float32')
                    gamma_value = params_gamma_beta[:gamma_split]
                    beta_value = params_gamma_beta[gamma_split:]
                    gamma.set_value(gamma_value.reshape(gamma_shape))
                    beta.set_value(beta_value.reshape(beta_shape))


    def fisher_sample(self):
        if self.sampler is None:
            print("preprocessing")
            if self.dico_fisher is None:
                print('you need to compute the fisher information first')
                return
            self.sampler = Sampling(self.build_mean(), self.dico_fisher)
            print('sampling ok')
        config = self.network.get_config()
        if self.network.__class__.__name__=='Sequential':
            new_model = Sequential.from_config(config)
        else:
            new_model = Model.from_config(config)
        new_params = self.sampler.sample()
        """
        means = self.sampler.mean
        
        for key in means:
            if np.max(np.abs(means[key] - new_params[key]))==0:
                print key
        print('kikou')
        import pdb; pdb.set_trace()
        """
        #tmp_prob = self.sampler.prob(new_params)
        new_model.compile(loss=self.network.loss,
                          optimizer=str.lower(self.network.optimizer.__class__.__name__),
                          metrics = self.network.metrics)
        new_model.set_weights(self.network.get_weights())
        self.copy_weights(new_model, new_params)
        return new_model
    
    
    """
    def sample_ensemble(self, data, N=2, nb_classe=1):
        (X_train, Y_train), (X_test, Y_test) = data
        def evaluate(model, data_train):
            (X_train,Y_train) = data_train
            yPreds = model.predict(X_train)
            yPred = np.argmax(yPreds, axis=1)
            yTrue = Y_train
    
            return metrics.accuracy_score(yTrue, yPred)

        models = [self.network] + [self.fisher_sample() for i in range(N)]
        probabilities = []
        data_train = (X_train, Y_train)
        import pdb
        for model in models:
            var = evaluate(model, data_train)
            probabilities.append(var)
        #probabilities = [evaluate(model, (X_train, Y_train)) for model in models]
        
        probabilities /= sum(probabilities)

        yPreds_ensemble = np.mean([alpha*model.predict(X_test) for alpha, model in zip(probabilities,models)], axis=0)
        yPred = np.argmax(yPreds_ensemble, axis=1)
        yTrue = Y_test
        accuracy = metrics.accuracy_score(yTrue, yPred) * 100
        print("Accuracy : ", accuracy)
    """
    
    def sample_ensemble(self, data, N=2, nb_classe=1):
        (X_train, Y_train), (X_test, Y_test) = data
        models = [self.fisher_sample() for i in range(N)]
        N-=1
        yPreds_ensemble = np.array([model.predict(X_test) for model in models]) #(N+1, 10000, 10)
        committee = np.mean(yPreds_ensemble, axis=0) # shape(10000, 10)
        def kl_divergence(committee_proba, member_proba):
            # for a fixed i
            # for a fixed j
            yPred=[]
            for n in range(len(Y_test)):
                predict = []
                for i in range(10):
                    proba_i_C = committee_proba[n,i]
                    predict_i=0
                    for j in range(N+1):
                        proba_i_j = member_proba[j,n, i]
                        predict_i += np.log(proba_i_j)*np.log(proba_i_j/proba_i_C)
                    predict.append(predict_i)
                yPred.append(np.argmin(predict))
            return np.array(yPred).astype('uint8')
        yPred = kl_divergence(committee, yPreds_ensemble)
        yTrue = Y_test
        accuracy = metrics.accuracy_score(yTrue, yPred) * 100
        print("Accuracy : ", accuracy)
        
    def correlation_sampling(self, data, N=3, n=5):
        (X_train, Y_train), (X_test, Y_test) = data
        predict_ensemble = [self.network.predict(X_train)]
        ensemble_model = [self.network]
        for l in range(N):
            p_ensemble = np.mean(predict_ensemble, axis=0)
            models = [self.fisher_sample() for p in range(n)]
            predict = [model.predict(X_train) for model in models]
        
            def evaluate(model, data_train):
                (X_train,Y_train) = data_train
                yPreds = model.predict(X_train)
                yPred = np.argmax(yPreds, axis=1)
                yTrue = Y_train
    
                return metrics.accuracy_score(yTrue, yPred)

            def kl(p_network, p_model):
                kl=[]
                n = len(Y_train)
                for i in range(n):
                    kl_n = 0
                    for j in range(10):
                        kl_n += p_network[i,j]*np.log(p_network[i,j]/p_model[i,j])
                        kl.append(kl_n)
                return np.mean(kl)
        
            index = np.argmax([kl(p_ensemble, predict[j]) for j in range(n)])
            model = models[index]
            ensemble_model.append(model)
            predict_ensemble.append(predict[index])
    
        # Accuracy
        #Ypred = np.argmax(np.mean([network.predict(X_test) for network in [self.network, model]], axis=0), axis=1)
        Ypred = np.argmax(np.mean([model.predict(X_test) for model in ensemble_model], axis=0), axis=1)
        Ytrue = Y_test
        print(metrics.accuracy_score(Ytrue, Ypred))
        
    def sample_ensemble(self, data, N=2, nb_classe=1):
        (X_train, Y_train), (X_test, Y_test) = data
        models = [self.network]; scores=[1.]
        for i in range(N):
            model, score = self.fisher_sample()
            models.append(model)
            scores.append(score)
        #models = [self.network] + [self.fisher_sample() for i in range(N)]
        min_prob = np.min(scores)
        scores -= min_prob
        scores = np.exp(scores)
        print(scores)
        """
        def evaluate(model, data_train):
            (X_train,Y_train) = data_train
            yPreds = model.predict(X_train)
            yPred = np.argmax(yPreds, axis=1)
            yTrue = Y_train
    
            return metrics.accuracy_score(yTrue, yPred)
        """
        #probabilities = [evaluate(model, (X_train, Y_train)) for model in models]
        #probabilities /= sum(probabilities)
        #yPreds_ensemble = np.array([alpha*model.predict(X_test) for alpha, model in zip(scores,models)]) #(N+1, 10000, 10)
        #yPreds_network = np.mean(self.network.predict(X_test), axis=0)
        yPred = np.argmax(np.mean(np.array([alpha*model.predict(X_test) for alpha, model in zip(scores,models)]), axis=0), axis=1) #(N+1, 10000, 10)
        """
        yPred = []
        for n in range(len(Y_test)):
            label = np.argmax(yPreds_ensemble[:,n,:]) % (N+1)
            yPred.append(label)

            if label !=Y_test[n]:

                print((np.argmax(yPreds_ensemble[0,n]), label, Y_test[n]))

        """
        yTrue = Y_test
        accuracy = metrics.accuracy_score(yTrue, yPred) * 100
        print("Accuracy : ", accuracy)