예제 #1
0
파일: SeqKNN.py 프로젝트: alherit/kd-switch
    def __init__(self, alpha_label, theta0=None, powers=None,
                 lmbda = .9999, switching=True):
        """Sequential mixture of k-NN experts, optionally switched with theta0.

        Args:
            alpha_label: number of labels; labels are used as indices into
                ``theta0`` and the per-label histograms.
            theta0: prior label distribution; defaults to uniform over
                ``alpha_label`` labels (and prints it).
            powers: one exponent per k-NN expert, passed to ``kofn_pow`` to
                choose k from the number of training points. Defaults to
                ``[.3, .5, .7, .9]``; a fresh list is built for each instance
                so instances never share (and can never co-mutate) the default.
            lmbda: weight of the k-NN estimate when mixing it with the uniform
                distribution.
            switching: if True, mix the expert mixture with theta0 through
                switch weights initialised from ``self.mu(1)``.
        """
        if powers is None:
            # fresh list per instance instead of a shared mutable default
            powers = [.3, .5, .7, .9]

        self.switching = switching

        if theta0 is None:
            theta0 = [1./alpha_label]*alpha_label
            print("default dist:", theta0)

        self.alpha_label = alpha_label

        self.theta0 = theta0

        # weight of the k-NN estimate against the uniform distribution
        self.lmbda = lp.LogWeightProb(lmbda)

        self.powers = powers
        self.N = len(powers)
        # cumulative probability of the observed sequence for each expert
        self.cumprob = np.repeat(lp.LogWeightProb(1.), self.N)

        # uniform prior over the experts
        self.weights = np.repeat(lp.LogWeightProb(1./self.N), self.N)

        self.trainingPoints = []
        self.trainingLabels = []

        if switching:
            # initialize switching
            # no switch before time 1
            self.w_theta0 = lp.LogWeightProb(1 - self.mu(1))
            # switch before time 1
            self.w_thetaz = lp.LogWeightProb(self.mu(1))
예제 #2
0
파일: SeqKNN.py 프로젝트: alherit/kd-switch
    def predictKNNlambda(self, point, label, k):
        """Return the lambda-smoothed k-NN probability of ``label`` at ``point``.

        For ``k > 0`` the estimate is the label frequency among the training
        points whose distance to ``point`` is at most the k-th smallest
        distance (ties are included, so more than k points may vote).
        Otherwise the prior ``predictTheta0(label)`` is used. In both cases
        the estimate p is then shrunk toward the uniform distribution:
        ``lmbda * p + (1 - lmbda) / alpha_label``.
        """
        if k > 0:
            distances = pairwise.pairwise_distances(X=[point], Y=self.trainingPoints)[0]

            # the k-th nearest neighbour lands at 0-based position k-1
            pos = k - 1
            radius = np.partition(distances, kth=pos)[pos]

            neighbour_labels = np.array(self.trainingLabels)[distances <= radius]
            label_counts = np.histogram(neighbour_labels, bins=self.alpha_label,
                                        range=[0, self.alpha_label])[0]
            frequencies = label_counts / np.sum(label_counts, dtype=float)
            base = lp.LogWeightProb(frequencies[label])
        else:
            base = self.predictTheta0(label)

        # lambda mix as in the paper: mix with uniform prob (could be theta0 as well)
        one = lp.LogWeightProb(1.)
        uniform = lp.LogWeightProb(1./self.alpha_label)
        return self.lmbda * base + (one - self.lmbda) * uniform
예제 #3
0
파일: SeqGP.py 프로젝트: alherit/kd-switch
    def __init__(self, alpha_label, theta0=None, scales = np.logspace(-20,28,num=(28+20)//4+1, base=2),
                  switching=True):
        """Sequential mixture of GP classifier experts, one per kernel scale.

        Args:
            alpha_label: number of labels (used as indices into ``theta0``).
            theta0: prior label distribution; defaults to uniform over
                ``alpha_label`` labels (and prints it).
            scales: kernel scales, one expert each. The default is the 13
                values 2**-20, 2**-16, ..., 2**28. ``num`` must be an integer:
                with true division ``(28+20)/4+1`` is the float 13.0, which
                recent NumPy rejects, and since defaults are evaluated when the
                method is defined that error would break loading the module.
                NOTE: the default array is shared by all instances; it is not
                modified here.
            switching: if True, mix with theta0 through switch weights
                initialised from ``self.mu(1)``.
        """
        self.switching = switching

        if theta0 is None:
            theta0 = [1./alpha_label]*alpha_label
            print("default dist:", theta0)

        self.alpha_label = alpha_label

        self.theta0 = theta0

        self.scales = scales
        self.N = len(scales)
        # cumulative probability of the observed sequence for each expert
        self.cumprob = np.repeat(lp.LogWeightProb(1.), self.N)

        # uniform prior over the experts
        self.weights = np.repeat(lp.LogWeightProb(1./self.N), self.N)

        self.trainingPoints = []
        self.trainingLabels = []

        # fitted classifier(s); None until they can be fitted
        self.gpc = None

        if switching:
            # initialize switching
            # no switch before time 1
            self.w_theta0 = lp.LogWeightProb(1 - self.mu(1))
            # switch before time 1
            self.w_thetaz = lp.LogWeightProb(self.mu(1))
예제 #4
0
 def predict(self, point, label, update=None):
     """Return P(label | point) from per-label densities and priors.

     For every (prior, density) pair of ``self.theta0`` and ``self.dists`` the
     joint weight ``density.probability(point)[0] * prior`` is formed; the
     result is the joint weight at index ``label`` divided by the sum of all
     joint weights. ``update`` is accepted for interface compatibility and is
     not used.
     """
     joint = []
     for prior, density in zip(self.theta0, self.dists):
         likelihood = lp.LogWeightProb(density.probability(np.array(point))[0])
         joint.append(likelihood * lp.LogWeightProb(prior))

     evidence = lp.LogWeightProb(0.)
     for weight in joint:
         evidence += weight

     return joint[label] / evidence
예제 #5
0
파일: SeqGP.py 프로젝트: alherit/kd-switch
    def predict(self, point, label, update):
        '''
        Give prob of label given point, using a posterior-weighted mixture of
        GP classifier experts (one per scale), optionally switched with theta0.

        Args:
            point: query point.
            label: observed label of ``point``.
            update: if truthy, after predicting add (point, label) to the
                training set, update the expert posteriors and switch weights,
                and refit the GP classifiers.

        Returns:
            LogWeightProb of ``label``, computed before any update.
        '''
        prob_mixture = lp.LogWeightProb(0.)
        n = len(self.trainingPoints)
        # MIX USING POSTERIOR
        for i in range(self.N):
            # assumes predictGP accepts the expert index i (self.gpc is a
            # list with one classifier per scale once fitted)
            prob_gp = self.predictGP(point, label, i)

            if update:
                self.cumprob[i] *= prob_gp
            prob_mixture += prob_gp * self.weights[i]

        if self.switching:
            theta0Prob = self.predictTheta0(label)
            switchProb = self.w_theta0 * theta0Prob + self.w_thetaz * prob_mixture

        if update:
            self.trainingPoints.append(point)
            self.trainingLabels.append(label)

            # posterior weights: uniform prior times cumulative probability, normalised
            prior = lp.LogWeightProb(1./self.N)
            acc = lp.LogWeightProb(0.)
            for i in range(self.N):
                self.weights[i] = self.cumprob[i] * prior
                acc += self.weights[i]

            for i in range(self.N):
                self.weights[i] /= acc

            if self.switching:
                # update switching posteriors
                # we saw n points, we just predicted the (n+1)-th and now we
                # compute the weights for the (n+2)-th. Both new weights must
                # be derived from the OLD w_theta0, so capture it before it is
                # overwritten.
                no_switch_mass = self.w_theta0 * theta0Prob
                self.w_theta0 = no_switch_mass * lp.LogWeightProb(1 - self.mu(n+2))
                self.w_thetaz = (no_switch_mass * lp.LogWeightProb(self.mu(n+2))
                                 + self.w_thetaz * prob_mixture)
                total = self.w_theta0 + self.w_thetaz
                self.w_theta0 /= total
                self.w_thetaz /= total

            # compute GP posterior for next time: fitting needs all the
            # classes, so wait until every label has been seen; once fitted,
            # refit after every update
            if self.gpc is not None or len(np.unique(self.trainingLabels)) == self.alpha_label:
                self.gpc = []
                for s in self.scales:
                    kernel = 1.0 * RBF(s)
                    self.gpc.append(GaussianProcessClassifier(kernel=kernel, optimizer=None, n_jobs=1).fit(self.trainingPoints, self.trainingLabels))

        if self.switching:
            return switchProb
        return prob_mixture
예제 #6
0
파일: SeqKNN.py 프로젝트: alherit/kd-switch
    def predict(self, point, label, update):
        '''
        Give prob of label given point, using KNNBMASW

        Expert i is a k-NN predictor with k = kofn_pow(n, powers[i]) (k = 0,
        i.e. theta0, before any training point exists). Experts are combined
        with their posterior weights; with ``switching`` the mixture is further
        mixed with theta0 through the switch weights w_theta0 / w_thetaz.

        Args:
            point: query point.
            label: observed label of ``point``.
            update: if truthy, after predicting add (point, label) to the
                training set and update the expert and switch posteriors.

        Returns:
            LogWeightProb of ``label``, computed before any update.
        '''
        prob_mixture = lp.LogWeightProb(0.)
        n = len(self.trainingPoints)
        # MIX USING POSTERIOR
        for i in range(self.N):
            if n == 0:
                k = 0
            else:
                k = self.kofn_pow(n, self.powers[i])

            prob_knn = self.predictKNNlambda(point, label, k)

            if update:
                self.cumprob[i] *= prob_knn
            prob_mixture += prob_knn * self.weights[i]

        if self.switching:
            theta0Prob = self.predictTheta0(label)
            switchProb = self.w_theta0 * theta0Prob + self.w_thetaz * prob_mixture

        if update:
            self.trainingPoints.append(point)
            self.trainingLabels.append(label)

            # posterior weights: uniform prior times cumulative probability, normalised
            prior = lp.LogWeightProb(1./self.N)
            acc = lp.LogWeightProb(0.)
            for i in range(self.N):
                self.weights[i] = self.cumprob[i] * prior
                acc += self.weights[i]

            for i in range(self.N):
                self.weights[i] /= acc

            if self.switching:
                # update switching posteriors
                # we saw n points, we just predicted the (n+1)-th and now we
                # compute the weights for the (n+2)-th. Both new weights must
                # be derived from the OLD w_theta0, so capture it before it is
                # overwritten.
                no_switch_mass = self.w_theta0 * theta0Prob
                self.w_theta0 = no_switch_mass * lp.LogWeightProb(1 - self.mu(n+2))
                self.w_thetaz = (no_switch_mass * lp.LogWeightProb(self.mu(n+2))
                                 + self.w_thetaz * prob_mixture)
                total = self.w_theta0 + self.w_thetaz
                self.w_theta0 /= total
                self.w_thetaz /= total

        if self.switching:
            return switchProb
        return prob_mixture
예제 #7
0
파일: SeqGP.py 프로젝트: alherit/kd-switch
 def predictGP(self, point, label, i=None):
     """Return the GP predictive probability of ``label`` at ``point``.

     Falls back to ``predictTheta0(label)`` while no classifier is fitted
     (``self.gpc is None``).

     Args:
         point: query point (array-like; reshaped to a single row).
         label: label index into the ``predict_proba`` output row.
         i: index of the expert to use when ``self.gpc`` holds a list of
             classifiers (one per scale). Leave as None when ``self.gpc`` is
             a single classifier.
     """
     if self.gpc is None:
         return self.predictTheta0(label)

     classifier = self.gpc if i is None else self.gpc[i]
     # asarray so list-valued points can be reshaped as well
     probas = classifier.predict_proba(np.asarray(point).reshape(1, -1))[0]
     return lp.LogWeightProb(probas[label])
예제 #8
0
 def predict(self, point, label, update):
     """Return the normalised probability of ``label`` at ``point`` (two classes).

     For each class c in {0, 1} the score is
     ``-0.5 * (d^T sigmainv[c] d + logDetSigma[c]) + lnTheta0[c]`` with
     ``d = point - mean[c]``; the two scores are normalised with ``softmax``.
     ``update`` is accepted for interface compatibility and is not used.
     """
     scores = []
     for c in (0, 1):
         centred = point - self.mean[c]
         quad_form = np.dot(np.dot(centred, self.sigmainv[c]), centred)
         scores.append(-.5 * (quad_form + self.logDetSigma[c]) + self.lnTheta0[c])

     posterior = softmax(scores)
     return lp.LogWeightProb(posterior[label])
예제 #9
0
    def __init__(self, depth=0, tree=None):
        """Create a (leaf) node of a kd-switch tree.

        Args:
            depth: depth of this node in the tree.
            tree: owning tree. Its ``dim`` and ``alpha_label`` are read right
                away, so a tree must be supplied despite the None default.
        """
        self.depth = depth
        self.tree = tree

        # random coordinate index in [0, tree.dim); presumably the axis used
        # when this node is split -- TODO confirm against computeProj
        self.projDir = np.random.randint(0, tree.dim)

        # samples observed in this node; unless tree.keep_items is set they
        # are moved to the children when those are created
        self.items = []

        self.Children = None  # the child nodes, once this node is split
        self.pivot = None  # splitting point

        # per-label counts of the samples seen by this node
        self.counts = [0] * tree.alpha_label

        # CTprob: probability of the whole sequence seen by this node under
        # CTS (i.e. mixing this node's own estimate with its children)
        self.CTprob = lp.LogWeightProb(1.)

        # CTS switching weights
        self.wa = lp.LogWeightProb(.5)
        self.wb = lp.LogWeightProb(.5)
예제 #10
0
    def predict(self, point, label, update=True):
        '''
        Give prob of label given point, using CT*

        Each tree predicts (point, label) and updates itself in the same call.
        The returned probability mixes the tree predictions with the weights
        held before this step; afterwards the weights are reset to each tree's
        root CTprob times the uniform prior 1/J, normalised over the trees.
        '''
        assert update # mandatory for kdSwitch

        mixture = lp.LogWeightProb(0.)
        # MIX USING POSTERIOR
        for tree, weight in zip(self.trees, self.weights):
            mixture += tree.predictUpdateKDSwitch(point, label) * weight

        uniform_prior = lp.LogWeightProb(1./self.J)
        evidence = lp.LogWeightProb(0.)
        for j, tree in enumerate(self.trees):
            self.weights[j] = tree.root.CTprob * uniform_prior
            evidence += self.weights[j]

        for j in range(self.J):
            self.weights[j] /= evidence

        return mixture
예제 #11
0
    def __init__(self, J, dim, alpha_label, theta0=None, keepItems=False, max_rot_dim = None, local_alpha= True, ctw=False, global_rot=True):
        """Build a forest of J independently randomised SeqTree predictors.

        Args:
            J: number of trees (must be >= 1).
            dim, alpha_label, theta0, keepItems, max_rot_dim, local_alpha, ctw:
                forwarded unchanged to every ``SeqTree``.
            global_rot: accepted for interface compatibility; not used here.

        Raises:
            ValueError: if ``J < 1`` (the uniform prior 1/J is undefined).
        """
        if J < 1:
            raise ValueError("J must be a positive number of trees")

        self.max_rot_dim = max_rot_dim

        self.local_alpha = local_alpha
        self.J = J
        self.ctw = ctw

        self.trees = [
            SeqTree(dim=dim, alpha_label=alpha_label, theta0=theta0,
                    keepItems=keepItems, max_rot_dim=max_rot_dim,
                    local_alpha=local_alpha, ctw=ctw)
            for _ in range(J)
        ]

        # initialize with uniform weights (once, after all trees are built)
        self.weights = np.repeat(lp.LogWeightProb(1./J), J)
예제 #12
0
    def __init__(self, dim, alpha_label, theta0 = None, local_alpha=True,
                 keepItems = False, max_rot_dim = None, ctw=False):
        """A single kd-switch tree over ``dim``-dimensional points.

        Args:
            dim: input dimension.
            alpha_label: number of labels.
            theta0: optional known label distribution. When given it is stored
                as ``P0Dist`` (LogWeightProb per label) so the global label
                proportion is not learned.
            local_alpha: stored; read by the nodes to pick the switching-rate
                argument (node sample count vs. tree-wide ``n``).
            keepItems: keep the samples in every node for post-hoc analysis
                (high memory consumption).
            max_rot_dim: if a positive number, draw a random rotation
                (``special_ortho_group``) of dimension ``min(dim, max_rot_dim)``;
                when ``dim > max_rot_dim`` a random subset of ``max_rot_dim``
                axes is selected and recorded in ``rot_mask``. None or a value
                <= 0 means no rotation.
            ctw: stored; presumably selects plain CTW (no switching) -- TODO
                confirm against ``alpha``.
        """
        self.ctw = ctw

        self.max_rot_dim = max_rot_dim
        self.local_alpha = local_alpha

        # max_rot_dim defaults to None, and `None > 0` raises TypeError in
        # Python 3, so test for None explicitly
        if max_rot_dim is not None and max_rot_dim > 0:
            # random rotation matrix
            print("generating random rotation matrix...")
            if dim > max_rot_dim:
                print("using " + str(max_rot_dim) + " random axes")
                rot_axes = np.random.choice(range(dim), max_rot_dim, replace=False)
                self.rot_mask = [i in rot_axes for i in range(dim)]
                rot_dim = max_rot_dim
            else:
                self.rot_mask = None
                rot_dim = dim

            self.R = special_ortho_group.rvs(rot_dim)
        else:
            self.R = None
            # defined in every branch so later reads never hit AttributeError
            self.rot_mask = None
            print("No rotation.")

        self.theta0 = theta0

        self.dim = dim
        self.alpha_label = alpha_label

        # set this value to avoid learning global proportion
        if self.theta0 is not None:
            self.P0Dist = [lp.LogWeightProb(p) for p in theta0]
        else:
            self.P0Dist = None

        self.n = 0

        # keep items in each node for post-hoc analysis: high memory consumption
        self.keepItems = keepItems

        # created last: the node reads dim and alpha_label from this tree
        self.root = SeqNode(depth=0, tree=self)
예제 #13
0
파일: SeqGP.py 프로젝트: alherit/kd-switch
 def predictTheta0(self, label):
     """Return the prior probability ``theta0[label]`` as a LogWeightProb."""
     prior = self.theta0[label]
     return lp.LogWeightProb(prior)
예제 #14
0
    def predictUpdateKDSwitch(self,point,label, updateStructure=True):
        """Predict ``label`` at ``point`` with kd-switch, then update this node.

        A leaf reached with ``updateStructure`` True is split at
        ``computeProj(point)``; its stored items are redistributed to two new
        children, so the tree grows like a kd tree. The child that receives
        ``point`` is then visited with ``updateStructure=False`` so it is not
        split again in the same call. Consequently the pure-leaf branch of
        step 2/3 is only reached with ``updateStructure`` False.

        Args:
            point: the query point.
            label: its observed label; used as an index into ``self.counts``.
            updateStructure: whether a leaf may be split during this call.

        Returns:
            The conditional probability of ``label`` given the past, i.e.
            this node's CTprob after the update divided by CTprob before it.
        """

        splitOccurredHere = False
        
        #STEP 1: UPDATE STRUCTURE if leaf (full-fledged kd tree algorithm)
        if updateStructure and self.Children is None:

            splitOccurredHere = True
            
            # split position derived from the current point
            self.pivot = self.computeProj(point)
            
            self.Children = [
                    self.__class__(depth=self.depth+1, tree=self.tree),
                    self.__class__(depth=self.depth+1, tree=self.tree)
                    ]
            
            for p in self.items: ### make all the update steps, notice that we are not yet adding current point
                i = self._selectBranch(p.point)
                self.Children[i].items.append(cm.LabeledPoint(point=p.point,label=p.label))
                self.Children[i].counts[p.label] += 1


            i = self._selectBranch(point)
            self.Children[i].items.append(cm.LabeledPoint(point=point,label=label)) #add current point to children's collection but not to label count, it will be done after prediction


            #initialize children using already existing symbols
            # (the KT estimate of the labels each child inherited; cm.KT
            # presumably returns a negative log-probability, hence the minus
            # sign -- TODO confirm). Both switch weights are scaled by it.
            for i in [0,1]:
                self.Children[i].CTprob = lp.LogWeightProb(log_wp=-cm.KT(self.Children[i].counts,alpha = self.tree.alpha_label))                  
                self.Children[i].wa *= self.Children[i].CTprob
                self.Children[i].wb *= self.Children[i].CTprob


            if not self.tree.keepItems:
                self.items = []  #all items have been sent down to children, so now should be empty
            
            
        #STEP 2 and 3: PREDICT AND UPDATE
        #save CTS_prob before update (lt_n = <n , cts prob up to previous symbol)        
        prob_CTS_lt_n = self.CTprob
        
        # this node's own sequential prediction of `label`: the known label
        # distribution at the root if the tree has one, otherwise sequential
        # KT from this node's counts
        if self.depth ==0 and self.tree.P0Dist is not None: #known distribution
            prob_KT_next = self.tree.P0Dist[label]
        else:
            prob_KT_next = lp.LogWeightProb(cm.seqKT(self.counts,label,alpha = self.tree.alpha_label)) #labels in self.items are not used, just counts are used 


        #now we can UPDATE the label count       
        self.counts[label] += 1


        if self.Children is None: 
            #PREDICT
            self.CTprob*=prob_KT_next #just KT

            #UPDATE
            self.wa*= prob_KT_next
            self.wb*= prob_KT_next
        else:
            #PREDICT (and recursive UPDATE)
            i = self._selectBranch(point)
            pr = self.Children[i].predictUpdateKDSwitch(point,label, updateStructure=not splitOccurredHere) 
            
            # mix this node's prediction (weight wa) with the child's (weight wb)
            self.CTprob = self.wa * prob_KT_next + self.wb * pr

            #UPDATE
            # switching rate from this node's sample count, or from the
            # tree-wide count n when local_alpha is off
            if self.tree.local_alpha:
                alpha_n_plus_1 = self.tree.alpha(sum(self.counts)+1)
            else:
                alpha_n_plus_1 = self.tree.alpha(self.tree.n+1)

            # switch update: each weight takes alpha of the mixed probability
            # plus (1 - 2*alpha) of its own continued prediction
            self.wa = alpha_n_plus_1 * self.CTprob + (lp.LogWeightProb(1.)- lp.LogWeightProb(2.)*alpha_n_plus_1)* self.wa * prob_KT_next;
            self.wb = alpha_n_plus_1 * self.CTprob + (lp.LogWeightProb(1.)- lp.LogWeightProb(2.)*alpha_n_plus_1)* self.wb * pr;

       
        prob_CTS_up_to_n = self.CTprob
        
        # conditional probability of the current label given the past
        return prob_CTS_up_to_n / prob_CTS_lt_n 
예제 #15
0
 def alpha(self, n):
     """Return the switching rate for step ``n``: 1/n, or 0 when ``ctw`` is set."""
     if not self.ctw:
         return lp.LogWeightProb(1.) / lp.LogWeightProb(n)
     # plain CTW never switches
     return lp.LogWeightProb(0.)
예제 #16
0
    def tst(self):
        """Run the sequential two-sample test until a stopping condition holds.

        Each iteration takes one labelled sample (from the pre-loaded sequence
        when ``self.seqIndex`` is set, otherwise from ``self.getSample()``).
        It asks ``self.model`` for the probability of the label (updating the
        model) and multiplies it into the cumulative model probability. The
        theta0 probability of the label goes into the cumulative theta0
        probability. The p-value is cumulative theta0 probability / cumulative
        model probability. The null is rejected the first time it is
        ``<= self.alpha``.

        The loop stops when any of these holds:
        - the time budget ``max_time`` is exhausted;
        - the sample budget ``max_samples`` is exhausted;
        - memory usage exceeds ``maxMem`` percent;
        - the data runs out;
        - ``stop_when_reject`` is set and the null has been rejected.

        One line per step is written to ``self.probs_file``, and the result
        is appended to ``self.XMLroot``.
        """
        # cumulate in log form
        cumCProb = logPr.LogWeightProb(1.)
        cumTheta0Prob = logPr.LogWeightProb(1.)
        self.pvalue = logPr.LogWeightProb(1.)

        # convert to log form
        alpha = logPr.LogWeightProb(self.alpha)

        i = 1

        reject = False

        # defined up front so the summary below works even when no sample
        # ends up being processed
        n = len(self.processed)
        nll = float('nan')

        # the sample budget is checked inside the loop so its message is printed
        while (self.max_time is None
               or time.process_time() - self.start_time <= self.max_time):

            if self.max_samples is not None and i > self.max_samples:
                print('Max samples %s reached. ' % self.max_samples)
                if not reject:
                    print("No difference detected so far.")
                break

            mem_percent = psutil.virtual_memory().percent
            if mem_percent > self.maxMem:
                # psutil.phymem_usage() no longer exists in psutil; report the
                # value that triggered the stop
                print('Not enough memory to proceed. Used percentage %s. ' %
                      mem_percent)
                if not reject:
                    print("No difference detected so far.")
                break

            if self.seqIndex is not None:
                if self.seqIndex < self.seqFeatures.shape[0]:
                    sample = cm.LabeledPoint(self.seqFeatures[self.seqIndex],
                                             self.seqLabels[self.seqIndex])
                    self.seqIndex += 1
                else:
                    sample = None
            else:
                sample = self.getSample()

            if sample is None:
                print('No more samples.')
                if not reject:
                    print("No difference detected so far.")
                break

            condProb = self.model.predict(sample.point, sample.label, update=True)
            theta0Prob = self.predictTheta0(sample.label)

            cumCProb *= condProb
            cumTheta0Prob *= theta0Prob

            # ratio of likelihoods; not capped at 1
            self.pvalue = cumTheta0Prob / cumCProb
            n = len(self.processed) + 1
            if n % 10 == 0:
                print('n=', n, 'p-value', self.pvalue, 'alpha', self.alpha)

            # normalised log-loss of the model so far
            nll = -cumCProb.getLogWeightProb() / n
            self.probs_file.write(
                str(time.process_time() - self.start_time) + " " +
                str(condProb) + " " + str(nll) + "\n")

            self.processed.append(sample)

            i += 1

            if not reject and self.pvalue <= alpha:
                reject = True
                n_reject = n
                p_value_reject = self.pvalue

                if self.stop_when_reject:
                    break

        if not reject:
            n_ = n
            p_value_ = self.pvalue
        else:
            n_ = n_reject
            p_value_ = p_value_reject

        print('n=', n_, 'p-value', p_value_, 'alpha', self.alpha)

        print('n=', n, 'norm_log_loss', nll)

        result = etree.SubElement(self.XMLroot, 'result')
        etree.SubElement(result, 'stopping_time').text = str(n_)
        etree.SubElement(result, 'reject').text = str(reject)

        etree.SubElement(result, 'final_norm_log2_loss').text = str(nll)
        etree.SubElement(result, 'final_n').text = str(n)
예제 #17
0
 def predictTheta0(self, label):
     """Return ``theta0[label]`` as a LogWeightProb, or probability 0 when
     no theta0 is configured."""
     if self.theta0 is None:
         return logPr.LogWeightProb(0.)
     return logPr.LogWeightProb(self.theta0[label])
예제 #18
0
    def __init__(self,
                 max_samples,
                 max_time,
                 alpha,
                 dataset,
                 seq_dataset_fname,
                 fname0,
                 fname1,
                 key0,
                 key1,
                 theta0,
                 maxTrials,
                 maxMem,
                 trial_n=None,
                 saveData=False,
                 probs_fname=None,
                 alpha_size=None,
                 data_gen_seed=None,
                 stop_when_reject=False,
                 synth_dim=None,
                 mean_gmd=None,
                 var_gvd=None,
                 estimate_median_pdist=False):
        """Configure a sequential two-sample test run and prepare its data.

        Seeds the random generators from ``trial_n`` (``data_gen_seed``
        defaults to it). Opens ``./probs_seq_<trial_n>.dat`` for writing.
        Records the parameters in an XML tree (``self.XMLroot``). Then sets
        up the data source chosen by the arguments:

        * ``dataset is None`` and ``seq_dataset_fname`` given: a pickled dict
          with ``"features"`` and ``"labels"``, consumed sequentially;
        * ``dataset is None`` and ``fname0``/``fname1`` given: one HDF5
          dataset per label (keys ``key0``/``key1``), read into memory;
        * ``"sklearn_<name>"``: ``datasets.load_<name>``, shuffled;
        * ``"gmm"``, ``"blobs"``, ``"gmd"``, ``"gvd"``, ``"sg"``: synthetic
          generators.

        Any other combination terminates with SystemExit("Wrong dataset").
        ``alpha_size`` must be an int: one sampled-index set is created per
        label. ``probs_fname`` and ``estimate_median_pdist`` are accepted but
        not used here.
        """
        self.tstData = None

        self.stop_when_reject = stop_when_reject

        if data_gen_seed is None:
            data_gen_seed = trial_n

        self.start_time = None

        self.alpha_size = alpha_size

        self.probs_file = open("./probs_seq_" + str(trial_n) + ".dat", "w")

        # max memory usage allowed
        self.maxMem = maxMem

        self.gadgetSampler = random.Random()
        self.gadgetSampler.seed(trial_n)

        self.dataSampler = random.Random()
        self.dataSampler.seed(data_gen_seed)

        # set global seed
        random.seed(trial_n)
        np.random.seed(trial_n)

        self.max_samples = max_samples
        self.max_time = max_time

        self.alpha = alpha

        self.maxTrials = maxTrials

        self.theta0 = theta0
        self.cumTheta0 = np.cumsum(self.theta0)

        # create xml output file
        self.XMLroot = etree.Element('experiment')
        params = etree.SubElement(self.XMLroot, 'parameters')
        etree.SubElement(params, 'maxMem').text = str(self.maxMem)
        etree.SubElement(params, 'max_time').text = str(self.max_time)
        etree.SubElement(params, 'max_samples').text = str(self.max_samples)
        etree.SubElement(params, 'alpha').text = str(self.alpha)
        etree.SubElement(params, 'maxTrials').text = str(self.maxTrials)
        etree.SubElement(params, 'theta0').text = str(self.theta0)
        etree.SubElement(params, 'gadget_sampler_seed').text = str(trial_n)
        etree.SubElement(params, 'global_seed').text = str(trial_n)

        xml_data = etree.SubElement(self.XMLroot, 'data')

        self.synthGen = None
        self.seqIndex = None

        self.unlimitedData = False

        if dataset is None and seq_dataset_fname is not None:
            # NOTE(review): pickle.load can execute arbitrary code; only use
            # with trusted dataset files. The with-block closes the handle.
            with open(seq_dataset_fname, "rb") as f:
                data = pickle.load(f)
            self.seqFeatures = data["features"]
            self.seqLabels = data["labels"]
            self.seqIndex = 0

            etree.SubElement(xml_data,
                             'seq_dataset_fname').text = seq_dataset_fname

            self.dim = self.seqFeatures.shape[1]
            etree.SubElement(xml_data, 'dim').text = str(self.dim)

        elif dataset is None and fname0 is not None and fname1 is not None:
            self.fname0 = fname0
            self.fname1 = fname1

            releaseFiles = True  # load datasets into memory and release files
            if releaseFiles:
                # use the first dataset of each file; [()] reads everything
                # into memory as a numpy array (faster). The with-block closes
                # both files even if a key is missing.
                with h5py.File(self.fname0, 'r') as self.hf0, \
                        h5py.File(self.fname1, 'r') as self.hf1:
                    self.datasets = {0: self.hf0[key0][()], 1: self.hf1[key1][()]}
            else:
                self.hf0 = h5py.File(self.fname0, 'r')
                self.hf1 = h5py.File(self.fname1, 'r')

                # use the first dataset of each file (kept open, read lazily)
                self.datasets = {0: self.hf0[key0], 1: self.hf1[key1]}

            etree.SubElement(xml_data, 'fname0').text = str(self.fname0)
            etree.SubElement(xml_data, 'fname1').text = str(self.fname1)
            etree.SubElement(xml_data, 'seed').text = str(data_gen_seed)

            self.dim = self.datasets[0].shape[1]
            etree.SubElement(xml_data, 'dim').text = str(self.dim)

        # the None guard sends an unusable dataset=None to "Wrong dataset"
        # instead of crashing on None.startswith
        elif dataset is not None and dataset.startswith("sklearn_"):
            from sklearn.utils import shuffle
            X, y = getattr(datasets, 'load' + dataset[7:])(return_X_y=True)
            X, y = shuffle(X, y, random_state=data_gen_seed)

            self.seqFeatures = X
            self.seqLabels = y
            self.seqIndex = 0
            self.dim = self.seqFeatures.shape[1]

            etree.SubElement(xml_data, 'sklearn').text = dataset[8:]

        elif dataset == "gmm":

            self.dim = synth_dim

            self.synthGen = NClassGaussianMixture(dim=synth_dim,
                                                  seed=data_gen_seed)

            self.unlimitedData = True
            self.X = None
            self.Y = None
            self.xi = None
            self.yi = None

            etree.SubElement(xml_data, 'dim').text = str(self.dim)

        elif dataset == "blobs":
            num_blobs = 4
            distance = 5
            stretch = 2
            angle = math.pi / 4.0

            self.synthGen = mySSBlobs(blob_distance=distance,
                                      num_blobs=num_blobs,
                                      stretch=stretch,
                                      angle=angle)

            self.sample_as_gretton(data_gen_seed)

            self.xi = 0  # next sample to consume
            self.yi = 0  # next sample to consume

            if saveData:
                if not os.path.exists("X"):
                    os.makedirs("X")
                if not os.path.exists("Y"):
                    os.makedirs("Y")

                # integer count for the file header
                n = self.max_samples // 2

                np.savetxt("X/blobs_X" + str(data_gen_seed) + ".dat",
                           self.X,
                           header=str(2) + " " + str(n),
                           comments='')
                # the Y file must hold the Y sample, not a second copy of X
                np.savetxt("Y/blobs_Y" + str(data_gen_seed) + ".dat",
                           self.Y,
                           header=str(2) + " " + str(n),
                           comments='')

            etree.SubElement(xml_data, 'seed').text = str(data_gen_seed)
            etree.SubElement(xml_data, 'num_blobs').text = str(num_blobs)
            etree.SubElement(xml_data, 'distance').text = str(distance)
            etree.SubElement(xml_data, 'stretch').text = str(stretch)
            etree.SubElement(xml_data, 'angle').text = str(angle)
            self.dim = 2
            etree.SubElement(xml_data, 'dim').text = str(self.dim)

        elif dataset == "gmd":
            if synth_dim is None:
                self.dim = 100
            else:
                self.dim = synth_dim

            meanshift = mean_gmd

            self.synthGen = mySSGaussMeanDiff(d=self.dim, my=meanshift)

            self.sample_as_gretton(data_gen_seed)

            self.xi = 0  # next sample to consume
            self.yi = 0  # next sample to consume

            etree.SubElement(xml_data, 'meanshift').text = str(meanshift)
            etree.SubElement(xml_data, 'dim').text = str(self.dim)

        elif dataset == "gvd":
            if synth_dim is None:
                self.dim = 50
            else:
                self.dim = synth_dim

            var_d1 = var_gvd

            self.synthGen = mySSGaussVarDiff(d=self.dim, var_d1=var_d1)

            self.sample_as_gretton(data_gen_seed)

            self.xi = 0  # next sample to consume
            self.yi = 0  # next sample to consume

            etree.SubElement(xml_data, 'vardiff').text = str(var_d1)
            etree.SubElement(xml_data, 'dim').text = str(self.dim)

        elif dataset == "sg":
            if synth_dim is None:
                self.dim = 50
            else:
                self.dim = synth_dim

            self.synthGen = mySSSameGauss(d=self.dim)

            self.sample_as_gretton(data_gen_seed)

            self.xi = 0  # next sample to consume
            self.yi = 0  # next sample to consume

            etree.SubElement(xml_data, 'same_gaussian').text = "default_value"
            etree.SubElement(xml_data, 'dim').text = str(self.dim)

        else:
            # same effect as the interactive exit() helper, without relying
            # on the site module
            raise SystemExit("Wrong dataset")

        # sets of row indexes already sampled
        # this is used to ensure sampling without replacement
        self.sampled = {i: set() for i in range(self.alpha_size)}

        self.processed = list()

        self.model = None  # will be set later

        # here will be stored the p-value resulting from the two-sample test
        self.pvalue = logPr.LogWeightProb(1.)