class OSM_strategy(Policy):
    """RLlib Policy that plays a precomputed optimal selfish-mining (OSM) strategy.

    On construction it builds the OSM Markov decision process, solves it once
    with policy iteration, and thereafter answers every observation by a pure
    table lookup into the solved policy (no learning happens at runtime).
    """

    def __init__(self, observation_space, action_space, config):
        Policy.__init__(self, observation_space, action_space, config)
        # Build and solve the selfish-mining MDP up front; the resulting
        # policy vector is the only thing consulted at action time.
        self.osm = OSM(config['alpha'], config['gamma'], config['blocks'])
        self.osm.MDP_matrix_init()
        transitions, rewards = self.osm.get_MDP_matrix()
        planner = mdptoolbox.mdp.PolicyIteration(transitions, rewards, 0.99)
        planner.run()
        self.blocks = config['blocks']
        self.optimal_policy = planner.policy

    def OSM_act(self, s):
        """Return the OSM action for state ``s`` = [a, h, o, fork-flag].

        States outside the solved MDP (chain lengths at or beyond the block
        horizon, or states missing from the OSM state dictionary) fall back
        to a greedy rule: act 1 when the attacker chain is strictly longer,
        otherwise act 0.
        """
        # Translate the numeric fork flag into the string label that the
        # OSM state dictionary is keyed on.
        if s[3] == constants.NORMAL:
            fork = 'normal'
        elif s[3] == constants.FORKING:
            fork = 'forking'
        else:
            fork = 'catch up'

        # The MDP state drops the third observation component entirely.
        lookup_key = (s[0], s[1], fork)

        # Beyond the solved horizon: greedy fallback.
        if s[0] >= self.blocks or s[1] >= self.blocks:
            return 1 if s[0] > s[1] else 0

        # Inside the horizon: use the solved policy when the state is known,
        # otherwise the same greedy fallback.
        if lookup_key in self.osm._state_dict:
            return self.optimal_policy[self.osm._name_to_index(lookup_key)]
        return 1 if s[0] > s[1] else 0

    def compute_actions(self, obs_batch, state_batches,
                        prev_action_batch=None, prev_reward_batch=None,
                        info_batch=None, episodes=None, **kwargs):
        """Map each observation to its OSM action (stateless; extras empty)."""
        # Observations arrive as floats; round the first four components back
        # to the integers the MDP lookup expects.
        actions = [
            self.OSM_act([int(round(obs[0])), int(round(obs[1])),
                          int(round(obs[2])), int(round(obs[3]))])
            for obs in obs_batch
        ]
        return actions, [], {}

    def learn_on_batch(self, samples):
        # Fixed strategy: nothing to learn.
        pass

    def get_weights(self):
        # No trainable weights to expose.
        pass

    def set_weights(self, weights):
        # No trainable weights to set.
        pass
# NOTE(review): the next line is the tail of a parenthesized `from ... import (`
# statement whose opening half is not visible in this chunk (the EXPR_* names
# presumably come from ray.tune.result — TODO confirm and restore the missing
# `from ... import (` prefix).
EXPR_PARAM_PICKLE_FILE, EXPR_PROGRESS_FILE, EXPR_RESULT_FILE)
from functools import reduce
from itertools import (chain, takewhile)
from ray.rllib.agents.ppo import PPOTrainer
#from OSM import OSM
import os
import csv
import math
import time
import constants
import mdptoolbox
from OSM import OSM
from BitcoinEnv import BitcoinEnv
from bitcoin_game import OSM_strategy
# Disabled usage example of the OSM_strategy policy wrapper, kept as a
# string literal (left byte-identical; do not reflow).
''' blocks = 5 osm_space = spaces.Box(low=np.zeros(4), high=np.array([blocks + 4, blocks + 4, blocks + 4, 3.])) osm = OSM_strategy(osm_space, spaces.Discrete(4), {'alpha':.15, 'gamma':0,'blocks':5}) print(osm.OSM_act([1, 1, 1, 0])) '''
# Build the selfish-mining MDP directly (alpha=0.15, gamma=0.5, 5-block
# horizon), solve it by policy iteration with discount 0.99, and print the
# resulting value function.
osm = OSM(.15, .5, 5)
osm.MDP_matrix_init()
P, R = osm.get_MDP_matrix()
solver = mdptoolbox.mdp.PolicyIteration(P, R, 0.99)
solver.run()
print(solver.V)