def main():
    """Run one local iteration step of the distributed logistic regression.

    Reads three required command-line options (unrecognised options are
    tolerated and ignored), loads the previous local state and the global
    node's output, runs ``logregr_local_iter`` on them, saves the resulting
    local state to the "current state" pickle and finally calls
    ``transfer()`` on the step's output. Returns nothing.
    """
    cli = ArgumentParser()
    # (flag, help text) for every required option.
    for flag, description in (
        ('-cur_state_pkl',
         'Path to the pickle file holding the current state.'),
        ('-prev_state_pkl',
         'Path to the pickle file holding the previous state.'),
        ('-global_step_db',
         'Path to db holding global step results.'),
    ):
        cli.add_argument(flag, required=True, help=description)
    # parse_known_args (not parse_args) so that extra options passed by the
    # caller do not abort the run; the leftovers are discarded.
    opts, _ = cli.parse_known_args()

    cur_state_file = path.abspath(opts.cur_state_pkl)
    prev_state_file = path.abspath(opts.prev_state_pkl)
    global_db_file = path.abspath(opts.global_step_db)

    # Previous local state and the global node's broadcast output.
    prev_state = StateData.load(prev_state_file).data
    from_global = LogRegrIter_Glob2Loc_TD.load(global_db_file)

    new_state, step_out = logregr_local_iter(
        local_state=prev_state, local_in=from_global
    )
    new_state.save(fname=cur_state_file)
    step_out.transfer()
def logregr_global_iter(global_state, global_in):
    """Run one global step of the distributed logistic regression.

    Combines the aggregated log-likelihood, gradient and Hessian coming
    from the local nodes into a new coefficient vector, records the
    change in log-likelihood used for the termination test, and packs the
    updated global state.

    Parameters
    ----------
    global_state : mapping
        Previous global state. Keys read: ``n_obs``, ``n_cols``, ``ll``,
        ``iter``, ``y_val_dict``, ``schema_X``, ``schema_Y``.
    global_in : object
        Aggregated local output; ``get_data()`` must return
        ``(ll, grad, hess)``.

    Returns
    -------
    tuple
        ``(global_state, global_out)``: a ``StateData`` holding the updated
        state (including ``delta`` and the incremented ``iter``) and a
        ``LogRegrIter_Glob2Loc_TD`` wrapping the new coefficients.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``hess`` is singular or not square.
    """
    # Unpack global state
    n_obs = global_state['n_obs']
    n_cols = global_state['n_cols']
    ll_old = global_state['ll']
    n_iter = global_state['iter']  # not named `iter`: avoid shadowing builtin
    y_val_dict = global_state['y_val_dict']
    schema_X = global_state['schema_X']
    schema_Y = global_state['schema_Y']

    # Unpack global input
    ll_new, grad, hess = global_in.get_data()

    # Compute new coefficients as hess^-1 . grad. Solving the linear system
    # is more accurate and cheaper than forming the explicit inverse.
    # NOTE(review): the result *replaces* the previous coefficients rather
    # than being added to them; that is a Newton step only if the local
    # node builds `grad` in IRLS working-response form (i.e. it already
    # includes hess . coeff_old) — confirm against logregr_local_iter.
    coeff = np.linalg.solve(hess, grad)

    # Update termination quantities
    delta = abs(ll_new - ll_old)
    n_iter += 1

    # Pack state and results
    global_state = StateData(n_obs=n_obs, n_cols=n_cols, ll=ll_new,
                             coeff=coeff, delta=delta, iter=n_iter,
                             y_val_dict=y_val_dict, schema_X=schema_X,
                             schema_Y=schema_Y)
    global_out = LogRegrIter_Glob2Loc_TD(coeff)
    return global_state, global_out
def logregr_global_init(global_in):
    """Build the initial global state of the distributed logistic regression.

    Parameters
    ----------
    global_in : object
        Aggregated local init output; ``get_data()`` must return
        ``(n_obs, n_cols, y_val_dict, schema_X, schema_Y)``.

    Returns
    -------
    tuple
        ``(global_state, global_out)``: a ``StateData`` with all-zero
        coefficients, ``iter`` set to 0 and the starting ``ll``, and a
        ``LogRegrIter_Glob2Loc_TD`` wrapping those zero coefficients.

    Raises
    ------
    ExaremeError
        If the aggregated number of observations is zero.
    """
    n_obs, n_cols, y_val_dict, schema_X, schema_Y = global_in.get_data()
    if n_obs == 0:
        raise ExaremeError('The selected variables contain 0 datapoints.')

    start_coeff = np.zeros(n_cols)
    global_state = StateData(
        n_obs=n_obs,
        n_cols=n_cols,
        # Starting log-likelihood: -2 * n_obs * log(2).
        # NOTE(review): an all-zero model has a Bernoulli log-likelihood of
        # -n_obs * log(2), so this is presumably chosen below it on purpose
        # so the first convergence delta is non-zero — confirm before
        # changing.
        ll=-2 * n_obs * np.log(2),
        coeff=start_coeff,
        iter=0,
        y_val_dict=y_val_dict,
        schema_X=schema_X,
        schema_Y=schema_Y,
    )
    return global_state, LogRegrIter_Glob2Loc_TD(start_coeff)