if __name__ == '__main__':
    # Smoke test: flatten the "true" toy-bars parameters to a differentiable
    # vector, then unflatten them back (autograd and tf variants) so the
    # round-trip can be compared by hand.
    import os

    from sklearn.externals import joblib

    # Required for load_dataset below; the original omitted this import.
    import slda_utils__dataset_manager
    from slda_utils__param_manager import (
        flatten_to_differentiable_param_vec,
        unflatten_to_common_param_dict,
        )
    import slda_loss__autograd

    # Simplest possible test
    # Load the toy bars dataset
    # Load "true" bars topics
    # Compute the loss
    dataset_path = os.path.expandvars("$PC_REPO_DIR/datasets/toy_bars_3x3/")
    dataset = slda_utils__dataset_manager.load_dataset(
        dataset_path, split_name='train')
    n_batches = 100

    # Load "true" 4 bars
    dim_P = dict(n_states=4, n_labels=1, n_vocabs=9)
    model_hyper_P = dict(
        alpha=1.1, tau=1.1, lambda_w=0.001, weight_x=1.0, weight_y=1.0)
    GP = joblib.load(
        os.path.join(dataset_path, "good_loss_x_K4_param_dict.dump"))
    # Keep only the parameters the flattener accepts. Iterate over a list
    # copy: deleting from a dict while iterating its live key view raises
    # RuntimeError in Python 3.
    for key in list(GP.keys()):
        if key not in ['topics_KV', 'w_CK']:
            del GP[key]
    param_vec = flatten_to_differentiable_param_vec(**GP)
    GPA = unflatten_to_common_param_dict(param_vec, **dim_P)
    # NOTE(review): unflatten_to_common_param_dict__tf is not imported in
    # this block and is not visible in this chunk — confirm which module
    # provides it (a tf param-manager sibling?) before running.
    GPB = unflatten_to_common_param_dict__tf(param_vec, **dim_P)
# NOTE(review): this physical line is whitespace-mangled. It contains two
# distinct pieces collapsed together:
#   (1) "return ans_dict / else: / return loss_ttl" — the tail of a function
#       whose definition begins before this chunk (presumably the
#       return_dict branch of calc_loss__slda); its true indentation cannot
#       be recovered from here, so it is left untouched.
#   (2) a __main__ smoke test that loads the toy-bars training split, loads
#       the saved "true" topics_KV / w_CK parameters via joblib, calls
#       calc_loss__slda with return_dict=True, and prints the summary
#       message from the resulting dict.
# TODO: restore the original line breaks/indentation from version control
# before editing this section further.
return ans_dict else: return loss_ttl if __name__ == '__main__': import os from sklearn.externals import joblib from slda_utils__dataset_manager import load_dataset # Simplest possible test # Load the toy bars dataset # Load "true" bars topics # Compute the loss dataset_path = os.path.expandvars("$PC_REPO_DIR/datasets/toy_bars_3x3/") dataset = load_dataset(dataset_path, split_name='train') # Load "true" 4 bars GP = joblib.load( os.path.join(dataset_path, "good_loss_x_K4_param_dict.dump")) topics_KV = GP['topics_KV'] w_CK = GP['w_CK'] loss_dict = calc_loss__slda(dataset=dataset, topics_KV=topics_KV, w_CK=w_CK, nef_alpha=1.1, tau=1.1, lambda_w=0.001, return_dict=True) print(loss_dict['summary_msg'])