def __init__(self, sim=None, final_momentum=0.9, initial_momentum=0.5,
             momentum_switchover=5, lr_s=1e-3, lr_nu=1e-3, lr_Yslack=1e-400,
             lr_theta=1e-2, maxIter=1000, initS=0.0, initSviaLineSearch=True,
             reg_lambda=0, tradeoffSLR=1, numReplicates=3, verbose=0):
    """Store hyper-parameters and build the compiled Theano loss/objective.

    For replicate ``rep``, the residual over its first
    ``lastGenerationIndex[rep]`` rows is
    ``target[:, rep] - Z(sig_(0.5*S*times[:, rep] + c[rep]), n, Theta) - Yslack``.
    The cost is the sum, over the replicates listed in ``replicateIndex``, of
    each replicate's mean squared residual. ``S``, ``c``, ``Yslack`` and
    ``Theta`` are Theano shared variables, each paired with a momentum buffer.

    Compiled functions:
        self.Loss_(target, times, n, lastGenerationIndex, replicateIndex)
            -> cost; shared parameters are not modified.
        self.Objective_(target, lrS, lrNu, lrYslack, lrTheta, times,
                        momentum, n, lastGenerationIndex, replicateIndex)
            -> cost, then applies one momentum step to c, S, Yslack and Theta.

    Args:
        sim: forwarded to ``self.setSIM`` (defined elsewhere). ``self.initC0``
            is read right after that call, so setSIM presumably sets it.
        final_momentum, initial_momentum, momentum_switchover: stored on the
            instance.
        lr_s, lr_nu, lr_Yslack, lr_theta: learning rates stored on the
            instance. NOTE: the literal ``1e-400`` underflows to exactly 0.0 in
            double precision, so the default Yslack learning rate is zero.
        maxIter: stored on the instance.
        initS: initial value of the shared ``S`` parameter.
        initSviaLineSearch: stored on the instance; not used in this method.
        reg_lambda: stored on the instance; the ``-log(nu0pred_)`` penalty it
            would weight is currently disabled (objective == cost).
        tradeoffSLR: stored on the instance.
        numReplicates: stored; also the length of the ``c`` momentum buffer.
        verbose: stored on the instance.
    """
    self.verbose = verbose
    self.numReplicates = numReplicates
    # These two were previously accepted and silently discarded. They are
    # stored before setSIM so that setSIM can still override them.
    self.initSviaLineSearch = initSviaLineSearch
    self.reg_lambda = reg_lambda
    self.setSIM(sim)

    # Optimiser hyper-parameters.
    self.momentum_ = T.scalar()
    self.final_momentum = final_momentum
    self.initial_momentum = initial_momentum
    self.momentum_switchover = momentum_switchover
    self.W = 3
    self.lr_s = lr_s
    self.lr_theta = lr_theta
    self.lr_Yslack = lr_Yslack
    self.lr_nu = lr_nu
    self.maxIter = maxIter
    self.initS = initS

    # Symbolic inputs.
    self.replicateIndex_ = T.vector(dtype='int64')
    self.lastGenerationIndex_ = T.vector(dtype='int64')
    self.lrTheta_ = T.scalar()
    self.lrS_ = T.scalar()
    self.lrNu_ = T.scalar()
    self.lrYslack_ = T.scalar()
    self.target_ = T.matrix(dtype=floatX)
    self.times_ = T.fmatrix("times")

    # Trainable shared parameters.
    self.Yslack__ = theano.shared(np.asarray(0, dtype=floatX))
    self.Theta__ = theano.shared(np.asarray(0, dtype=floatX))
    self.n_ = T.scalar("n ")
    self.S__ = theano.shared(np.asarray(self.initS, dtype=floatX))
    self.c__ = theano.shared(self.initC0, 'c')

    # Momentum (velocity) buffers, one per trainable parameter.
    self.weightUpdateS__ = theano.shared(np.asarray(0, dtype=floatX))
    self.weightUpdatec__ = theano.shared(np.zeros(self.numReplicates, dtype=floatX))
    self.weightUpdateYslack__ = theano.shared(np.asarray(0, dtype=floatX))
    self.weightUpdateTheta__ = theano.shared(np.asarray(0, dtype=floatX))

    def _replicate_cost(rep):
        # Mean squared residual of one replicate, truncated to its valid rows.
        last = self.lastGenerationIndex_[rep]
        residual = (self.target_[:last, rep]
                    - Z(sig_(0.5 * self.S__ * self.times_[:last, rep] + self.c__[rep]),
                        self.n_, self.Theta__)
                    - self.Yslack__)
        return (residual ** 2).sum() / last

    self.cost_, _ = theano.scan(_replicate_cost, sequences=self.replicateIndex_)
    self.cost_ = self.cost_.sum()
    self.Loss_ = theano.function(
        inputs=[self.target_, self.times_, self.n_,
                self.lastGenerationIndex_, self.replicateIndex_],
        outputs=self.cost_)

    # NOTE(review): this uses ``Z - Yslack``, whereas the residual above
    # corresponds to a prediction of ``Z + Yslack``. Confirm the intended sign.
    self.nu0pred_ = Z(sig_(self.c__[0]), self.n_, self.Theta__) - self.Yslack__
    # The reg_lambda * -log(nu0pred_) regularizer is disabled; the objective is
    # the plain cost.
    self.obj_ = self.cost_
    self.gS_, self.gc_, self.gYslack_, self.gTheta_ = T.grad(
        self.obj_, [self.S__, self.c__, self.Yslack__, self.Theta__])

    def _momentum_step(velocity, param, rate, grad):
        # v' = m*v - lr*g and p' = p + m*v - lr*g. Theano applies both updates
        # simultaneously, so each one reads the old velocity.
        return [(velocity, self.momentum_ * velocity - rate * grad),
                (param, param + self.momentum_ * velocity - rate * grad)]

    self.updatesS = _momentum_step(self.weightUpdateS__, self.S__, self.lrS_, self.gS_)
    self.updatesc = _momentum_step(self.weightUpdatec__, self.c__, self.lrNu_, self.gc_)
    self.updatesYslack = _momentum_step(self.weightUpdateYslack__, self.Yslack__,
                                        self.lrYslack_, self.gYslack_)
    self.updatesTheta = _momentum_step(self.weightUpdateTheta__, self.Theta__,
                                       self.lrTheta_, self.gTheta_)
    self.updates = self.updatesc + self.updatesS + self.updatesYslack + self.updatesTheta
    self.Objective_ = theano.function(
        [self.target_, self.lrS_, self.lrNu_, self.lrYslack_, self.lrTheta_,
         self.times_, self.momentum_, self.n_,
         self.lastGenerationIndex_, self.replicateIndex_],
        self.obj_, on_unused_input='warn', updates=self.updates,
        allow_input_downcast=True)
    self.tradeoffSLR = tradeoffSLR
    self.initYslack = 0
def __init__(self, sim, final_momentum=0.9, initial_momentum=0.5,
             momentum_switchover=5, lr_s=1e-6, lr_nu=1e-2, lr_Yslack=1e-2,
             maxIter=1000, initS=0.0, initSviaLineSearch=True):
    """Derive model constants from ``sim`` and compile the Theano functions.

    Prediction (``self.pred_``):
    ``Z(sig_(0.5*S*times + c), n, theta) + Yslack``. Here ``c`` is a shared
    vector of length ``sim.numReplicates``, broadcast across the columns of
    ``times``. The cost is half the summed squared error against ``target``
    over the first ``sim.numReplicates`` columns.

    Compiled functions:
        self.Feedforward_(times, n, theta) -> prediction.
        self.Loss_(target, pred)           -> cost for a supplied prediction.
        self.Objective_(target, lrS, lrNu, lrTheta, times, momentum, n, theta)
            -> cost, then one momentum step on c, S and Yslack.

    NOTE(review): ``lr_Yslack`` is stored as ``self.lr_theta``. Yslack is
    stepped with the ``lrTheta`` input, and its shared variable is named
    ``'theta'``. The naming looks historical; confirm before renaming.
    """
    self.initSviaLineSearch = initSviaLineSearch
    self.sim = sim
    self.initYslack = 0

    # Constants derived from the simulation object.
    s = self.sim
    self.n = s.N * 2
    self.theta = s.theta / (s.L / s.winSize)
    # c starts at logit of the smallest X0 entry, identical for every replicate.
    self.initC0 = np.ones(s.numReplicates, dtype=floatX) * logit(sim.X0.min())
    # One copy of the generation times per replicate, as float32 columns.
    self.Times = np.tile(sim.getGenerationTimes(), (s.numReplicates, 1)).T.astype(np.float32)

    # Optimiser settings.
    self.momentum_ = T.scalar()
    self.final_momentum = final_momentum
    self.initial_momentum = initial_momentum
    self.momentum_switchover = momentum_switchover
    self.W = 3
    self.lr_s = lr_s
    self.lr_theta = lr_Yslack
    self.lr_nu = lr_nu
    self.maxIter = maxIter
    self.initS = initS

    # Symbolic inputs.
    self.lrS_, self.lrNu_, self.lrTheta_ = T.scalar(), T.scalar(), T.scalar()
    self.target_ = T.matrix()
    self.times_ = T.fmatrix("times")
    self.theta_ = T.scalar()

    def zero_scalar():
        # Fresh floatX scalar shared variable initialised to 0.
        return theano.shared(np.asarray(0, dtype=floatX))

    # Trainable shared parameters and their momentum buffers.
    self.Yslack__ = theano.shared(np.asarray(0, dtype=floatX), 'theta')
    self.n_ = T.scalar("n ")
    self.S__ = theano.shared(np.asarray(self.initS, dtype=floatX))
    self.c__ = theano.shared(self.initC0, 'c')
    self.weightUpdateS__ = zero_scalar()
    self.weightUpdatec__ = theano.shared(np.zeros(s.numReplicates, dtype=floatX))
    self.weightUpdateYslack__ = zero_scalar()

    # Forward model.
    sigmoid_arg = 0.5 * self.S__ * self.times_ + self.c__
    self.pred_ = Z(sig_(sigmoid_arg), self.n_, self.theta_) + self.Yslack__
    self.Feedforward_ = theano.function(inputs=[self.times_, self.n_, self.theta_],
                                        outputs=self.pred_)

    # Half squared error, accumulated column by column.
    total = 0
    for col in range(s.numReplicates):
        total += 0.5 * ((self.target_[:, col] - self.pred_[:, col]) ** 2).sum()
    self.cost_ = total
    self.Loss_ = theano.function(inputs=[self.target_, self.pred_], outputs=self.cost_)

    self.gS_, self.gc_, self.gYslack_ = T.grad(self.cost_, [self.S__, self.c__, self.Yslack__])

    def momentum_step(velocity, param, rate, grad):
        # New velocity m*v - lr*g, and param advanced by m*v - lr*g (old v).
        return [(velocity, self.momentum_ * velocity - rate * grad),
                (param, param + self.momentum_ * velocity - rate * grad)]

    self.updatesS = momentum_step(self.weightUpdateS__, self.S__, self.lrS_, self.gS_)
    self.updatesc = momentum_step(self.weightUpdatec__, self.c__, self.lrNu_, self.gc_)
    self.updatesYslack = momentum_step(self.weightUpdateYslack__, self.Yslack__,
                                       self.lrTheta_, self.gYslack_)
    self.updates = self.updatesc + self.updatesS + self.updatesYslack
    self.Objective_ = theano.function(
        [self.target_, self.lrS_, self.lrNu_, self.lrTheta_,
         self.times_, self.momentum_, self.n_, self.theta_],
        self.cost_, on_unused_input='warn', updates=self.updates,
        allow_input_downcast=True)