示例#1
0
文件: xdgmm.py 项目: broulston/XDGMM
 def __init__(self, n_components=1, n_iter=0, tol=1E-5,
              method='astroML', labels=None, random_state=None,
              V=None, mu=None, weights=None, filename=None, w=0.):
     """Initialize the model, either from parameters or from a file.

     Parameters
     ----------
     n_components: int
         Number of Gaussian components; passed to astroML_XDGMM.
     n_iter: int
         Maximum number of iterations. 0 selects a method-dependent
         default: 100 for 'astroML', 10**9 for 'Bovy'.
     tol: float
         Convergence tolerance; passed to astroML_XDGMM.
     method: string
         Fitting method, either 'astroML' or 'Bovy'.
     labels: array-like or None
         Stored as self.labels.
     random_state: int or None
         Passed to astroML_XDGMM as random_state.
     V, mu, weights: array-like or None
         Model covariances, means and component weights. These are set
         by the fit() method but can be supplied here; they are also
         copied onto self.GMM (weights as ``alpha``).
     filename: string or None
         If given, the model parameters are read from this file via
         read_model() and the corresponding arguments above are ignored
         (``method`` is still validated first).
     w: float
         Stored as self.w on both construction paths. It is used outside
         this method (NOTE(review): presumably a regularization term for
         fitting -- confirm).

     Raises
     ------
     ValueError
         If ``method`` is neither 'astroML' nor 'Bovy'.
     """
     if method not in ('astroML', 'Bovy'):
         raise ValueError("Fitting method must be 'astroML' or " +
                          "'Bovy'.")

     # read_model() does not set w, so assign it before branching;
     # otherwise a model loaded from a file would have no self.w at all.
     self.w = w

     if filename is not None:
         self.read_model(filename)
         return

     self.n_components = n_components
     if n_iter != 0:
         self.n_iter = n_iter
     elif method == 'astroML':
         self.n_iter = 100
     else:
         self.n_iter = 10**9
     self.tol = tol
     self.random_state = random_state
     self.method = method
     self.labels = labels

     # Model parameters. These are set by the fit() method but
     # can be set at initialization.
     self.V = V
     self.mu = mu
     self.weights = weights

     self.GMM = astroML_XDGMM(n_components,
                              max_iter=self.n_iter, tol=tol,
                              random_state=random_state)
     self.GMM.mu = mu
     self.GMM.V = V
     self.GMM.alpha = weights
示例#2
0
文件: xdgmm.py 项目: broulston/XDGMM
 def read_model(self, filename):
     """Read the parameters of the model from a file.

     Read the parameters of a model from a file in the format saved
     by save_model and set the parameters of this model to those
     from the file. A new astroML_XDGMM is built for self.GMM and the
     means, covariances and weights are copied onto it.

     Layout parsed here (0-based line indices):
       line 2     : n_components,n_iter,tol,method,random_state
       line 4     : comma-separated labels, or ``No labels``
       line 6     : comma-separated component weights
       lines 8... : one comma-separated mean vector per line, up to a
                    ``# covars`` line
       then       : covariance matrices, one row per line, consecutive
                    matrices separated by a ``#`` line
     Lines 0, 1, 3, 5 and 7 are not parsed. Blank lines in the mean and
     covariance sections are skipped. Marker comparisons ignore
     surrounding whitespace, so CRLF line endings are accepted.

     Parameters
     ----------
     filename: string
         Name of the file to read from.

     Raises
     ------
     ValueError
         If no ``# covars`` line is present, or a numeric field cannot be
         parsed.
     """
     with open(filename, 'r') as infile:
         inlines = infile.readlines()

     def parse_row(line):
         # One comma-separated line of numbers -> 1-D float array.
         return np.array([float(value) for value in line.split(',')])

     params = [param.strip() for param in inlines[2].split(',')]
     self.n_components = int(params[0])
     self.n_iter = int(params[1])
     self.tol = float(params[2])
     self.method = params[3]
     self.random_state = None if params[4] == 'None' else int(params[4])

     if inlines[4].strip() == 'No labels':
         self.labels = None
     else:
         # Keep only the first whitespace-delimited token of each label.
         self.labels = np.array([label.split()[0]
                                 for label in inlines[4].split(',')])

     self.weights = parse_row(inlines[6])

     # Means: one row per component until the '# covars' marker line.
     mu = []
     nextidx = None
     for i in range(8, len(inlines)):
         stripped = inlines[i].strip()
         if stripped == '# covars':
             nextidx = i + 1
             break
         if stripped:
             mu.append(parse_row(inlines[i]))
     if nextidx is None:
         raise ValueError("No '# covars' line found in model file " +
                          str(filename))
     self.mu = np.array(mu)

     # Covariances: rows accumulate into the current matrix; a '#' line
     # closes it. Empty matrices (from trailing or repeated separators)
     # are not emitted, so np.array(V) stays rectangular.
     V = []
     currV = []
     for line in inlines[nextidx:]:
         stripped = line.strip()
         if stripped == '#':
             if currV:
                 V.append(np.array(currV))
             currV = []
         elif stripped:
             currV.append(parse_row(line))
     if currV:
         V.append(np.array(currV))
     self.V = np.array(V)

     self.GMM = astroML_XDGMM(n_components=self.n_components,
                              max_iter=self.n_iter, tol=self.tol,
                              random_state=self.random_state)
     self.GMM.mu = self.mu
     self.GMM.V = self.V
     self.GMM.alpha = self.weights