class SCL_One_Cavity_Tracker_Model:
	def __init__(self,scl_long_tuneup_controller):
		self.scl_long_tuneup_controller = scl_long_tuneup_controller
		self.scl_accSeq = self.scl_long_tuneup_controller.scl_accSeq
		self.part_tracker = AlgorithmFactory.createParticleTracker(self.scl_accSeq)
		self.part_tracker.setRfGapPhaseCalculation(true)
		self.part_probe_init = ProbeFactory.createParticleProbe(self.scl_accSeq,self.part_tracker)
		self.scenario = Scenario.newScenarioFor(self.scl_accSeq)
		self.scenario.setSynchronizationMode(Scenario.SYNC_MODE_DESIGN)
		self.scenario.resync()
		# in the dictionary we will have 
		# cav_wrappers_param_dict[cav_wrapper] = [cavAmp,phase,[[gapLattElem,E0,ETL],...]]
		# E0 and ETL are parameters for all RF gaps
		self.cav_wrappers_param_dict = {}
		cav_wrappers = self.scl_long_tuneup_controller.cav_wrappers
		self.cav_amp_phase_dict = {}
		for cav_wrapper in cav_wrappers:
			amp = cav_wrapper.cav.getDfltCavAmp()
			phase = cav_wrapper.cav.getDfltCavPhase()
			self.cav_amp_phase_dict[cav_wrapper] = (amp,phase)
		#------ Make rf gap arrays for each cavity. 
		#------ The elements are IdealRfGap instances not AcceleratorNode. 
		#------ self.cavToGapsDict has {cav_name:[irfGaps]}
		rfGaps = self.scl_accSeq.getAllNodesWithQualifier(AndTypeQualifier().and((OrTypeQualifier()).or(RfGap.s_strType)))	
		self.cavToGapsDict = {}		
		for cav_wrapper in cav_wrappers:
			self.cavToGapsDict[cav_wrapper] = []
			for rfGap in rfGaps:
				if(rfGap.getId().find(cav_wrapper.cav.getId()) >= 0):
					irfGaps = self.scenario.elementsMappedTo(rfGap)
					self.cavToGapsDict[cav_wrapper].append(irfGaps[0])				
		#self.scenario.setModelInput(self.gap_first,RfGapPropertyAccessor.PROPERTY_PHASE,phase)
		#self.scenario.setModelInput(self.gap_first,RfGapPropertyAccessor.PROPERTY_ETL,val)
		#self.scenario.setModelInput(self.gap_first,RfGapPropertyAccessor.PROPERTY_E0,val)
		#self.scenario.setModelInput(quad,ElectromagnetPropertyAccessor,PROPERTY_FIELD,val)		
		#----------------------------------------------------------------
		self.scan_gd = BasicGraphData()
		self.harmonicsAnalyzer = HarmonicsAnalyzer(2)
		self.eKin_in = 185.6
		self.cav_amp = 14.0
		self.cav_phase_shift = 0.
		#------------------------
		self.active_cav_wrapper = null
		self.solver = null
		
	def restoreInitAmpPhases(self):
		cav_wrappers = self.scl_long_tuneup_controller.cav_wrappers
		for cav_wrapper in cav_wrappers:
			(amp,phase) = self.cav_amp_phase_dict[cav_wrapper]
			self.active_cav_wrapper.cav.updateDesignAmp(amp)
			self.active_cav_wrapper.cav.updateDesignPhase(phase)
		self.setActiveCavity(null)
		
	def getEkinAmpPhaseShift(self):
		return (self.eKin_in,self.cav_amp,self.cav_phase_shift)
		
	def setModelAmpPhaseToActiveCav(self,amp,phase,phase_shift):
		if(self.active_cav_wrapper != null):
			self.active_cav_wrapper.cav.updateDesignAmp(amp)
			self.active_cav_wrapper.cav.updateDesignPhase(phase-phase_shift)
			
	def getAvgGapPhase(self):
		#------------- calculate avg. RF gap phase -----------
		if(self.active_cav_wrapper == null): return 0.
		rf_gap_arr = self.cavToGapsDict[self.active_cav_wrapper]
		phase_rf_gaps_avg = 0.
		for irfGap in rf_gap_arr:
			phase_rf_gaps_avg += makePhaseNear(irfGap.getPhase(),0.)
		phase_rf_gaps_avg /= len(rf_gap_arr)
		phase_rf_gaps_avg = makePhaseNear((phase_rf_gaps_avg*180./math.pi)%360.,0.)
		return phase_rf_gaps_avg
			
	def getModelEnergyOut(self,eKin_in,amp,phase,phase_shift):
		if(self.active_cav_wrapper == null): return 0.
		self.setModelAmpPhaseToActiveCav(amp,phase,phase_shift)
		part_probe = ParticleProbe(self.part_probe_init)
		part_probe.setKineticEnergy(eKin_in*1.0e+6)
		self.scenario.setProbe(part_probe)	
		self.scenario.resync()
		self.scenario.run()
		return self.scenario.getTrajectory().finalState().getKineticEnergy()/1.0e+6
			
	def fillOutEneregyVsPhase(self,eKin_in,amp,phase_shift,phase_arr):
		self.scan_gd.removeAllPoints()
		if(self.active_cav_wrapper == null): return
		self.active_cav_wrapper.cav.updateDesignAmp(amp)
		self.scenario.resync()
		irfGap = self.cavToGapsDict[self.active_cav_wrapper][0]	
		for phase in phase_arr:
			part_probe = ParticleProbe(self.part_probe_init)
			part_probe.setKineticEnergy(eKin_in*1.0e+6)
			self.scenario.setProbe(part_probe)		
			#self.active_cav_wrapper.cav.updateDesignPhase(phase-phase_shift)
			#self.scenario.resync()
			irfGap.setPhase((phase-phase_shift)*math.pi/180.)
			self.scenario.run()
			eKin_out = self.scenario.getTrajectory().finalState().getKineticEnergy()/1.0e+6
			self.scan_gd.addPoint(phase,eKin_out)
		return self.scan_gd
			
	def getDiff2(self,eKin_in,amp,phase_shift):
		if(self.active_cav_wrapper == null): return 0.
		scan_gdExp = self.active_cav_wrapper.eKinOutPlot
		n_points = scan_gdExp.getNumbOfPoints()
		if(n_points <= 0): return 0.
		phase_arr = []
		for ip in range(n_points):
			phase_arr.append(scan_gdExp.getX(ip))
		scan_gd = self.fillOutEneregyVsPhase(eKin_in,amp,phase_shift,phase_arr)
		diff2 = 0.
		for ip in range(n_points):
			diff2 += (scan_gd.getY(ip) - scan_gdExp.getY(ip))**2
		diff2 /= n_points
		return diff2
			
	def setActiveCavity(self,cav_wrapper):
		self.active_cav_wrapper = cav_wrapper
		if(cav_wrapper != null):
			self.gap_list = cav_wrapper.cav.getGapsAsList()
			self.gap_first = self.gap_list.get(0)
			self.gap_last = self.gap_list.get(self.gap_list.size()-1)
			self.scenario.setStartNode(self.gap_first.getId())
			self.scenario.setStopNode(self.gap_last.getId())
		else:
			self.scenario.unsetStartNode()
			self.scenario.unsetStopNode()
			self.gap_first = null
			self.gap_last = null	
			self.gap_list = null
			
	def harmonicsAnalysisStep(self):
		if(self.active_cav_wrapper == null): return
		self.eKin_in = self.active_cav_wrapper.eKin_in
		self.cav_amp = 14.0
		self.cav_phase_shift = 0.
		#--------- first iteration
		self.getDiff2(self.eKin_in,self.cav_amp,self.cav_phase_shift)
		err = self.harmonicsAnalyzer.analyzeData(self.scan_gd)	
		harm_function = self.harmonicsAnalyzer.getHrmonicsFunction()
		energy_amp_test = harm_function.getParamArr()[1]
		energy_amp_exp = self.active_cav_wrapper.energy_guess_harm_funcion.getParamArr()[1]
		self.cav_amp = self.cav_amp*energy_amp_exp/energy_amp_test
		#--------- second iteration	
		self.getDiff2(self.eKin_in,self.cav_amp,self.cav_phase_shift)
		err = self.harmonicsAnalyzer.analyzeData(self.scan_gd)	
		harm_function = self.harmonicsAnalyzer.getHrmonicsFunction()
		energy_amp_test = harm_function.getParamArr()[1]
		energy_amp_exp = self.active_cav_wrapper.energy_guess_harm_funcion.getParamArr()[1]
		self.cav_amp = self.cav_amp*energy_amp_exp/energy_amp_test
		max_model_energy_phase = self.harmonicsAnalyzer.getPositionOfMax()
		max_exp_energy_phase = self.active_cav_wrapper.energy_guess_harm_funcion.findMax()
		self.cav_phase_shift = makePhaseNear(-(max_model_energy_phase - max_exp_energy_phase),0.)
		#print "debug model max=",max_model_energy_phase," exp=",max_exp_energy_phase," shift=",self.cav_phase_shift," amp=",self.cav_amp
		
	def fit(self):
		if(self.active_cav_wrapper == null): return
		variables = ArrayList()
		delta_hint = InitialDelta()
		#----- variable eKin_in
		var = Variable("eKin_in",self.eKin_in, - Double.MAX_VALUE, Double.MAX_VALUE)
		variables.add(var)
		delta_hint.addInitialDelta(var,0.3)
		#----- variable cavity amplitude
		var = Variable("cav_amp",self.cav_amp, - Double.MAX_VALUE, Double.MAX_VALUE)
		variables.add(var)
		delta_hint.addInitialDelta(var,self.cav_amp*0.01)
		#----- variable cavity phase offset
		var = Variable("phase_offset",self.cav_phase_shift, - Double.MAX_VALUE, Double.MAX_VALUE)
		variables.add(var)
		delta_hint.addInitialDelta(var,1.0)
		#-------- solve the fitting problem
		scorer = CavAmpPhaseScorer(self,variables)
		maxSolutionStopper = SolveStopperFactory.maxEvaluationsStopper(120) 
		self.solver = Solver(SimplexSearchAlgorithm(),maxSolutionStopper)
		problem = ProblemFactory.getInverseSquareMinimizerProblem(variables,scorer,0.0001)
		problem.addHint(delta_hint)
		self.solver.solve(problem)
		#------- get results
		trial = self.solver.getScoreBoard().getBestSolution()
		err2 = scorer.score(trial,variables)	
		[self.eKin_in,self.cav_amp,self.cav_phase_shift] = scorer	.getTrialParams(trial)	
		self.active_cav_wrapper.eKin_in = self.eKin_in
		self.active_cav_wrapper.designPhase = makePhaseNear(self.active_cav_wrapper.livePhase - self.cav_phase_shift,0.)
		self.active_cav_wrapper.eKin_err = math.sqrt(err2)
		cav_phase = self.active_cav_wrapper.livePhase
		self.active_cav_wrapper.eKin_out = self.getModelEnergyOut(self.eKin_in,self.cav_amp,cav_phase,self.cav_phase_shift)
		#print "debug cav=",self.active_cav_wrapper.alias," shift=",self.cav_phase_shift," amp=",self.cav_amp," err2=",	math.sqrt(err2)," ekinOut=",	self.active_cav_wrapper.eKin_out		
		#----- this defenition of the avg. gap phase will be replaced by another with self.model_eKin_in
		self.active_cav_wrapper.avg_gap_phase = self.getAvgGapPhase()
		self.active_cav_wrapper.designAmp = self.cav_amp
		self.solver = null
		#----make theory graph plot
		x_arr = []
		y_arr = []
		for i in range(self.scan_gd.getNumbOfPoints()):
			phase = self.scan_gd.getX(i)
			y = self.scan_gd.getY(i)
			x_arr.append(phase)
			y_arr.append(y)
		self.active_cav_wrapper.eKinOutPlotTh.addPoint(x_arr,y_arr)			
		
	def stopFitting(self):
		if(self.solver != null):
			self.solver.stopSolving()