# Train a baseline LSTM on SMAP_AM, then one LstmCloseModel per day-shift k in dLst.
# NOTE(review): this physical line holds many statements whose newlines and
#   indentation were lost (whitespace-collapsed paste). As written it is not
#   valid Python; restore the line breaks from the original script.
# Apparent flow, read from the tokens only:
#   - c = constant attributes (dbCsv.varConst, doNorm=True, rmNan=True) from df;
#     y = 'SMAP_AM' series from df (doNorm=True, rmNan=False).
#   - nx = last-axis size of x plus last-axis size of c; ny = 1.
#   - train rnn.CudnnLstmModel(hiddenSize=64) with crit.RmseLoss for nEpoch
#     epochs, miniBatch=[100, 30], then save it to outFolder as 'test-LSTM'.
#   - for each k in dLst: shift both ends of the ty1 window back by k days
#     (sd/ed), reload 'SMAP_AM' over that window from subset 'CONUSv4f1' as obs,
#     and train an rnn.LstmCloseModel(hiddenSize=64) on inputs (x, obs), target y
#     and constants c.
# NOTE(review): which statements sit inside `for k in dLst:` is not recoverable
#   here; presumably everything after the `for` header - confirm.
# NOTE(review): sd/ed are utils.time.t2dt(...) results minus dt.timedelta, not
#   the raw ty1 values; confirm DataframeCsv accepts that type in tRange.
# NOTE(review): the last train.trainModel(...) call is still open at the end of
#   this line; its remaining arguments continue on the next line of the file.
c = df.getDataConst(dbCsv.varConst, doNorm=True, rmNan=True) y = df.getDataTs('SMAP_AM', doNorm=True, rmNan=False) nx = x.shape[-1] + c.shape[-1] ny = 1 model = rnn.CudnnLstmModel(nx=nx, ny=ny, hiddenSize=64) lossFun = crit.RmseLoss() model = train.trainModel(model, x, y, c, lossFun, nEpoch=nEpoch, miniBatch=[100, 30]) modelName = 'test-LSTM' train.saveModel(outFolder, model, nEpoch, modelName=modelName) for k in dLst: sd = utils.time.t2dt(ty1[0]) - dt.timedelta(days=k) ed = utils.time.t2dt(ty1[1]) - dt.timedelta(days=k) df2 = hydroDL.data.dbCsv.DataframeCsv(rootDB=rootDB, subset='CONUSv4f1', tRange=[sd, ed]) obs = df2.getDataTs('SMAP_AM', doNorm=True, rmNan=False) model = rnn.LstmCloseModel(nx=nx, ny=ny, hiddenSize=64) lossFun = crit.RmseLoss() model = train.trainModel(model, (x, obs), y, c, lossFun,
# Train two CudnnLstm models with an observation series appended to the
# inputs, then (under `if 'test' in doLst:`) load CAMELS test-period data.
# NOTE(review): this physical line begins with '#', so Python treats the WHOLE
#   line as a comment. That includes xz2, the model2/model3 training and saving,
#   and the `if 'test' in doLst:` block, so none of it executes. The newlines
#   appear to have been lost in a whitespace-collapsed paste; restore them from
#   the original script before relying on this code.
# Apparent flow once newlines are restored:
#   - xz2 = np.concatenate([x1, z2], axis=2), i.e. z2 appended to x1 on axis 2.
#   - nx = last-axis size of x1 plus last-axis size of c1; ny = 1; RMSE loss.
#   - a plain-LSTM baseline (model1, saved as 'LSTM') is commented out.
#     NOTE(review): delete it or re-enable it.
#   - model2 is trained on xz1 and saved as 'DA-1'; model3 is trained on xz2
#     and saved as 'DA-7'. Both are CudnnLstmModel(nx=nx + 1, hiddenSize=64)
#     with miniBatch=(50, 365).
#     NOTE(review): the +1 presumably accounts for exactly one appended z
#     channel; xz1 is built outside this view - confirm it also adds one.
#   - `if 'test' in doLst:` loads the CAMELS subset 'all' for tRange
#     [20050101, 20150101]:
#       - x2 from camels.forcingLst and c2 from camels.attrLstSel
#         (doNorm=True, rmNan=True);
#       - yt2 from getDataObs(doNorm=False, rmNan=False).squeeze();
#       - z1 from getDataObs(doNorm=True, rmNan=True) over a window starting
#         one day earlier, [20041231, 20141231].
#     The body of this `if` continues past this line.
# z2 = interp.interpNan1d(z2, mode='pre') xz2 = np.concatenate([x1, z2], axis=2) ny = 1 nx = x1.shape[-1] + c1.shape[-1] lossFun = crit.RmseLoss() # model1 = rnn.CudnnLstmModel(nx=nx, ny=ny, hiddenSize=64) # model1 = train.trainModel( # model1, x1, y1, c1, lossFun, nEpoch=nEpoch, miniBatch=(50, 365)) # train.saveModel(outFolder, model1, nEpoch, modelName='LSTM') model2 = rnn.CudnnLstmModel(nx=nx + 1, ny=ny, hiddenSize=64) model2 = train.trainModel( model2, xz1, y1, c1, lossFun, nEpoch=nEpoch, miniBatch=(50, 365)) train.saveModel(outFolder, model2, nEpoch, modelName='DA-1') model3 = rnn.CudnnLstmModel(nx=nx + 1, ny=ny, hiddenSize=64) model3 = train.trainModel( model3, xz2, y1, c1, lossFun, nEpoch=nEpoch, miniBatch=(50, 365)) train.saveModel(outFolder, model3, nEpoch, modelName='DA-7') if 'test' in doLst: df2 = camels.DataframeCamels(subset='all', tRange=[20050101, 20150101]) x2 = df2.getDataTS(varLst=camels.forcingLst, doNorm=True, rmNan=True) c2 = df2.getDataConst(varLst=camels.attrLstSel, doNorm=True, rmNan=True) yt2 = df2.getDataObs(doNorm=False, rmNan=False).squeeze() dfz1 = camels.DataframeCamels(subset='all', tRange=[20041231, 20141231]) z1 = dfz1.getDataObs(doNorm=True, rmNan=True) # z1 = interp.interpNan1d(z1, mode='pre')