# Example #1
0
    dkf = DKF(params, paramFile=pfile)
displayTime('Building dkf', start_time, time.time())

# Set save prefix
savef = os.path.join(params['savedir'], params['unique_id'])
print 'Savefile: ', savef
start_time = time.time()

# Learn the model (see stinfmodel/learning.py)
savedata = DKF_learn.learn(dkf,
                           dataset['train'],
                           dataset['mask_train'],
                           epoch_start=0,
                           epoch_end=params['epochs'],
                           batch_size=params['batch_size'],
                           savefreq=params['savefreq'],
                           savefile=savef,
                           dataset_eval=dataset['val'],
                           mask_eval=dataset['mask_val'],
                           replicate_K=params['replicate_K'],
                           shuffle=False,
                           cond_vals_train=train_cond_vals,
                           cond_vals_eval=val_cond_vals)
displayTime('Running DKF', start_time, time.time())

# Evaluate bound on test set (see stinfmodel/evaluate.py)
savedata['bound_test'] \
   = DKF_evaluate.evaluateBound(dkf, dataset['test'], dataset['mask_test'],
                                batch_size=params['batch_size'],
                                cond_vals=val_cond_vals)
saveHDF5(savef + '-final.h5', savedata)
print 'Test Bound: ', savedata['bound_test']
# Example #2
0
if os.path.exists(reloadFile):
    pfile=params.pop('paramFile')
    assert os.path.exists(pfile),pfile+' not found. Need paramfile'
    print 'Reloading trained model from : ',reloadFile
    print 'Assuming ',pfile,' corresponds to model'
    dkf  = DKF(params, paramFile = pfile, reloadFile = reloadFile) 
else:
    pfile= params['savedir']+'/'+params['unique_id']+'-config.pkl'
    print 'Training model from scratch. Parameters in: ',pfile
    dkf  = DKF(params, paramFile = pfile)
displayTime('Building dkf',start_time, time.time())


savef     = os.path.join(params['savedir'],params['unique_id']) 
print 'Savefile: ',savef
start_time= time.time()
savedata = DKF_learn.learn(dkf, dataset['train'], dataset['mask_train'], 
                                epoch_start =0 , 
                                epoch_end = params['epochs'], 
                                batch_size = params['batch_size'],
                                savefreq   = params['savefreq'],
                                savefile   = savef,
                                dataset_eval=dataset['valid'],
                                mask_eval  = dataset['mask_valid'],
                                replicate_K = 5
                                )
displayTime('Running DKF',start_time, time.time())
#Save file log file
saveHDF5(savef+'-final.h5',savedata)
#import ipdb;ipdb.set_trace()