Пример #1
0
                       dataset_eval=validData)

# Report elapsed time for the training run and persist the collected
# statistics dict `savedata` to '<savef>-final.h5'.
displayTime('Running Model', start_time, time.time())
saveHDF5(savef + '-final.h5', savedata)

# Test-set evaluation is skipped entirely for datasets whose name contains 'wiki'.
if 'wiki' not in params['dataset']:
    evaluate = {}
    # 1) Evaluate the model object as it stands at the end of training.
    test_results = Evaluate.evaluateBound(model,
                                          dataset['test'],
                                          batch_size=params['batch_size'])
    # NOTE(review): 'perp_0' / 'perp_f' look like initial/final perplexity
    # bounds judging by the names -- confirm in Evaluate.evaluateBound.
    evaluate['test_perp_0'] = test_results['perp_0']
    evaluate['test_perp_f'] = test_results['perp_f']
    print 'Test Bound: ', evaluate['test_perp_f']
    kname = 'valid_perp_f'
    # 2) Work w/ the best model thus far: getLowestError's result is unpacked
    # as (epoch, value, index); only the epoch is used to locate the checkpoint.
    epochMin, valMin, idxMin = getLowestError(savedata[kname])
    # Checkpoint path: '<pfile minus -config.pkl>-EP<epoch>-params.npz'.
    reloadFile = pfile.replace('-config.pkl', '') + '-EP' + str(
        int(epochMin)) + '-params.npz'
    print 'Loading model from epoch : ', epochMin  # full path is in reloadFile
    # NOTE(review): presumably makes Model skip building its training
    # functions when only evaluating -- confirm in Model.
    params['validate_only'] = True
    bestModel = Model(params,
                      paramFile=pfile,
                      reloadFile=reloadFile,
                      additional_attrs=additional_attrs)
    # Re-run the same test evaluation on the reloaded best checkpoint.
    test_results = Evaluate.evaluateBound(bestModel,
                                          dataset['test'],
                                          batch_size=params['batch_size'])
    # Copy every returned metric under a '<key>_best' name, then save both
    # the end-of-training and best-checkpoint metrics to '<savef>-evaluate.h5'.
    for k in test_results:
        evaluate[k + '_best'] = test_results[k]
    saveHDF5(savef + '-evaluate.h5', evaluate)
Пример #2
0
                           dataset['train'],
                           dataset['mask_train'],
                           epoch_start=0,
                           epoch_end=params['epochs'],
                           batch_size=params['batch_size'],
                           savefreq=params['savefreq'],
                           savefile=savef,
                           dataset_eval=dataset['valid'],
                           mask_eval=dataset['mask_valid'],
                           replicate_K=params['replicate_K'],
                           shuffle=False)
# Report elapsed time for the DKF training run.
displayTime('Running DKF', start_time, time.time())
"""
Load the best DKF based on the validation error
"""
# getLowestError's result is unpacked as (epoch, value, index); the epoch
# selects which saved checkpoint to reload.
epochMin, valMin, idxMin = getLowestError(savedata['valid_bound'])
# Checkpoint path: '<pfile minus -config.pkl>-EP<epoch>-params.npz'.
reloadFile = pfile.replace('-config.pkl', '') + '-EP' + str(
    int(epochMin)) + '-params.npz'
print 'Loading from : ', reloadFile
# NOTE(review): presumably makes DKF skip building its training functions
# when only evaluating -- confirm in DKF.
params['validate_only'] = True
dkf_best = DKF(params, paramFile=pfile, reloadFile=reloadFile)
# Out-parameter: evaluateBound is expected to fill this dict; its
# 'tsbn_bound' entry is read below (KeyError if it is not populated).
additional = {}
# Test-set bound of the best checkpoint.
# NOTE(review): S=2 is presumably a sample count -- confirm in DKF_evaluate.
savedata['bound_test_best'] = DKF_evaluate.evaluateBound(
    dkf_best,
    dataset['test'],
    dataset['mask_test'],
    S=2,
    batch_size=params['batch_size'],
    additional=additional)
savedata['bound_tsbn_test_best'] = additional['tsbn_bound']
savedata['ll_test_best'] = DKF_evaluate.impSamplingNLL(
Пример #3
0
# ========================================

# #change the dataset to one of ['jsb','nottingham','musedata','piano']
# DATASET= 'jsb'
# DATASET= 'ipython'
DATASET = 'synthetic'
DIR = '../expt/chkpt-' + DATASET + '/'
# DIR    = './chkpt-'+DATASET+'/'
# assert os.path.exists('../expt/chkpt-'+DATASET+'/'),'Run the shell files in ../expt first'
# prefix = 'DMM_lr-0_0008-dh-200-ds-100-nl-relu-bs-20-ep-2000-rs-600-rd-0_1-infm-R-tl-2-el-2-ar-2000_0-use_p-approx-rc-lstm-DKF-ar'
prefix = 'DMM_lr-0_0008-dh-200-ds-100-nl-relu-bs-20-ep-20-rs-600-rd-0_1-infm-R-tl-2-el-2-ar-2_0-use_p-approx-rc-lstm-uid'
# prefix = 'DMM_lr-0_0008-dh-40-ds-2-nl-relu-bs-200-ep-40-rs-80-rd-0_1-infm-R-tl-2-el-2-ar-2_0-use_p-approx-rc-lstm-uid'
stats = loadHDF5(os.path.join(DIR, prefix + '-final.h5'))
# stats  = loadHDF5(os.path.join('chkpt-ipython/DMM_lr-0_0008-dh-40-ds-2-nl-relu-bs-200-ep-40-rs-80-rd-0_1-infm-R-tl-2-el-2-ar-2_0-use_p-approx-rc-lstm-uid-EP30-stats.h5'))
epochMin, valMin, idxMin = getLowestError(stats['valid_bound'])
pfile = os.path.join(DIR, prefix + '-config.pkl')

params = readPickle(pfile, quiet=True)[0]
print 'Hyperparameters in: ', pfile, 'Found: ', os.path.exists(pfile)
EP = '-EP' + str(int(epochMin))
reloadFile = os.path.join(DIR, prefix + EP + '-params.npz')
print 'Model parameters in: ', reloadFile
#Don't load the training functions for the model since its time consuming
params['validate_only'] = True
dmm_reloaded = DMM(params, paramFile=pfile, reloadFile=reloadFile)

# forViz/chkpt-ipython/DMM_lr-0_0008-dh-40-ds-2-nl-relu-bs-200-ep-40-rs-80-rd-0_1-infm-R-tl-2-el-2-ar-2_0-use_p-approx-rc-lstm-uid-EP30-stats.h5

# expt/chkpt-synthetic/DMM_lr-0_0008-dh-200-ds-100-nl-relu-bs-20-ep-20-rs-600-rd-0_1-infm-R-tl-2-el-2-ar-2_0-use_p-approx-rc-lstm-uid-final.h5