def __init__(self, n_components=1, n_iter=0, tol=1E-5, method='astroML',
             labels=None, random_state=None, V=None, mu=None,
             weights=None, filename=None, w=0.):
    """Set up the model, either from arguments or from a saved file

    Parameters
    ----------
    n_components: integer (optional)
        Number of components, passed to the underlying astroML_XDGMM
        object. Default is 1.
    n_iter: integer (optional)
        Maximum number of iterations. The default of 0 means "pick a
        default for the chosen method": 100 for 'astroML' and 10**9
        for 'Bovy'.
    tol: float (optional)
        Tolerance passed to the underlying astroML_XDGMM object.
    method: string (optional)
        Fitting method; must be 'astroML' or 'Bovy'.
    labels: optional
        Stored unchanged on ``self.labels``.
    random_state: optional
        Passed to the underlying astroML_XDGMM object.
    V, mu, weights: optional
        Model parameters. These are set by the fit() method but can
        be set at initialization. They are also copied onto the
        underlying GMM object as ``V``, ``mu`` and ``alpha``.
    filename: string (optional)
        If given, the model is loaded with read_model(filename) and
        every other argument except ``method`` (validated only) and
        ``w`` is ignored.
    w: float (optional)
        Stored unchanged on ``self.w``; not otherwise used here.

    Raises
    ------
    ValueError
        If ``method`` is neither 'astroML' nor 'Bovy'.
    """
    if method not in ('astroML', 'Bovy'):
        raise ValueError("Fitting method must be 'astroML' or 'Bovy'.")

    # read_model() does not assign w, so set it before branching;
    # otherwise an instance loaded from a file would lack the attribute.
    self.w = w

    if filename is not None:
        # read_model() sets every other attribute, including self.GMM.
        self.read_model(filename)
        return

    self.n_components = n_components
    if n_iter != 0:
        self.n_iter = n_iter
    elif method == 'astroML':
        self.n_iter = 100
    else:
        self.n_iter = 10**9
    self.tol = tol
    self.random_state = random_state
    self.method = method
    self.labels = labels

    # Model parameters. These are set by the fit() method but
    # can be set at initialization.
    self.V = V
    self.mu = mu
    self.weights = weights

    # Built only on this path so a model loaded from a file is never
    # overwritten by the constructor defaults.
    self.GMM = astroML_XDGMM(n_components, max_iter=self.n_iter, tol=tol,
                             random_state=random_state)
    self.GMM.mu = mu
    self.GMM.V = V
    self.GMM.alpha = weights
def read_model(self, filename):
    """Read the parameters of the model from a file

    Read the parameters of a model from a file in the format saved
    by save_model and set the parameters of this model to those
    from the file

    The parser relies on fixed line positions (0-indexed):
    line 2 holds ``n_components,n_iter,tol,method,random_state``;
    line 4 holds comma-separated labels or the text ``No labels``;
    line 6 holds comma-separated weights; lines 8 onward hold one
    comma-separated row of mu per line up to a ``# covars`` line;
    the remaining lines hold covariance rows, with a line containing
    only ``#`` separating one matrix from the next.

    Parameters
    ----------
    filename: string
        Name of the file to read from.

    Raises
    ------
    ValueError
        If the file has no ``# covars`` line, or a numeric field
        cannot be parsed.
    """
    # The context manager closes the file even if reading fails.
    with open(filename, 'r') as infile:
        inlines = infile.readlines()

    def parse_row(text):
        # One comma-separated line of numbers -> 1-D float array.
        return np.array([float(entry) for entry in text.split(',')])

    params = inlines[2].split(',')
    self.n_components = int(params[0])
    self.n_iter = int(params[1])
    self.tol = float(params[2])
    self.method = params[3]
    # Compare stripped text so a missing trailing newline is harmless.
    seed = params[4].strip()
    self.random_state = None if seed == 'None' else int(seed)

    if inlines[4].strip() == 'No labels':
        self.labels = None
    else:
        # Keep only the first whitespace-delimited token of each
        # entry, which also drops the trailing newline.
        self.labels = np.array(
            [label.split()[0] for label in inlines[4].split(',')]
        )

    self.weights = parse_row(inlines[6])

    # Rows of mu run from line 8 up to the '# covars' marker.
    mu_rows = []
    covars_start = None
    for idx in range(8, len(inlines)):
        if inlines[idx].strip() == '# covars':
            covars_start = idx + 1
            break
        mu_rows.append(parse_row(inlines[idx]))
    if covars_start is None:
        raise ValueError("No '# covars' section found in " + str(filename))
    self.mu = np.array(mu_rows)

    # A lone '#' line closes the current covariance matrix; the last
    # matrix has no closing marker, so it is appended after the loop.
    matrices = []
    current = []
    for line in inlines[covars_start:]:
        if line.strip() == '#':
            matrices.append(np.array(current))
            current = []
            continue
        current.append(parse_row(line))
    matrices.append(np.array(current))
    self.V = np.array(matrices)

    self.GMM = astroML_XDGMM(n_components=self.n_components,
                             max_iter=self.n_iter, tol=self.tol,
                             random_state=self.random_state)
    self.GMM.mu = self.mu
    self.GMM.V = self.V
    self.GMM.alpha = self.weights