Example #1
File: bitcoin_game.py Project: wuwuz/SquirRL
import mdptoolbox
import constants

from OSM import OSM
from ray.rllib.policy.policy import Policy


class OSM_strategy(Policy):
    """Fixed RLlib policy that plays a precomputed optimal selfish-mining (OSM) strategy."""

    def __init__(self, observation_space, action_space, config):
        Policy.__init__(self, observation_space, action_space, config)
        # Build the selfish-mining MDP once and solve it with policy iteration.
        self.osm = OSM(config['alpha'], config['gamma'], config['blocks'])
        self.osm.MDP_matrix_init()
        P, R = self.osm.get_MDP_matrix()
        solver = mdptoolbox.mdp.PolicyIteration(P, R, 0.99)
        solver.run()
        self.blocks = config['blocks']
        self.optimal_policy = solver.policy

    def OSM_act(self, s):
        # Map the numeric fork flag back to the string labels used by the OSM state space.
        curr_s = list(s)
        if s[3] == constants.NORMAL:
            curr_s[3] = 'normal'
        elif s[3] == constants.FORKING:
            curr_s[3] = 'forking'
        else:
            curr_s[3] = 'catch up'
        # The solved MDP is indexed by (a, h, fork); the third observation entry is dropped.
        smaller_state = tuple(curr_s[:2] + [curr_s[3]])
        # Outside the solved state space (a chain at the block cap, or an unknown state),
        # fall back to a greedy rule: action 1 if the attacker's chain is longer, else action 0.
        if curr_s[0] >= self.blocks or curr_s[1] >= self.blocks:
            return 1 if curr_s[0] > curr_s[1] else 0
        if smaller_state in self.osm._state_dict:
            return self.optimal_policy[self.osm._name_to_index(smaller_state)]
        return 1 if curr_s[0] > curr_s[1] else 0

    def compute_actions(self,
                        obs_batch,
                        state_batches,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        # Round each float observation [a, h, o, fork] back to the integer MDP state.
        actions = []
        for obs in obs_batch:
            a = int(round(obs[0]))
            h = int(round(obs[1]))
            o = int(round(obs[2]))
            f = int(round(obs[3]))
            actions.append(self.OSM_act([a, h, o, f]))
        # RLlib expects (actions, RNN state outs, extra info).
        return actions, [], {}

    def learn_on_batch(self, samples):
        # The strategy is precomputed and fixed; there is nothing to learn.
        pass

    def get_weights(self):
        # No trainable weights to report or synchronize.
        pass

    def set_weights(self, weights):
        pass
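
A minimal usage sketch for the policy above (an illustration, not part of the project: the gym spaces mirror the disabled test in Example #2, and the observation layout [a, h, o, fork] follows compute_actions):

import numpy as np
from gym import spaces

blocks = 5
obs_space = spaces.Box(low=np.zeros(4),
                       high=np.array([blocks + 4, blocks + 4, blocks + 4, 3.]))
# Config keys match those read in OSM_strategy.__init__.
policy = OSM_strategy(obs_space, spaces.Discrete(4),
                      {'alpha': 0.15, 'gamma': 0.5, 'blocks': blocks})
# One observation; floats are rounded back to the integer MDP state.
actions, _, _ = policy.compute_actions([np.array([1., 1., 1., 0.])], [])
print(actions)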
Example #2
File: testOSM.py Project: wuwuz/SquirRL
from ray.tune.result import (EXPR_PARAM_PICKLE_FILE, EXPR_PROGRESS_FILE,
                             EXPR_RESULT_FILE)
from functools import reduce
from itertools import (chain, takewhile)
from ray.rllib.agents.ppo import PPOTrainer
import os
import csv
import math
import time
import constants
import mdptoolbox

from OSM import OSM
from BitcoinEnv import BitcoinEnv
from bitcoin_game import OSM_strategy
'''
Disabled smoke test for the RLlib wrapper (needs `import numpy as np` and
`from gym import spaces` to run):
blocks = 5
osm_space = spaces.Box(low=np.zeros(4),
                high=np.array([blocks + 4, blocks + 4, blocks + 4, 3.]))
osm = OSM_strategy(osm_space, spaces.Discrete(4), {'alpha': .15, 'gamma': 0, 'blocks': 5})
print(osm.OSM_act([1, 1, 1, 0]))
'''

# Build and solve the OSM MDP directly (alpha = .15, gamma = .5, block cap = 5),
# then print the value function of the optimal policy.
osm = OSM(.15, .5, 5)
osm.MDP_matrix_init()
P, R = osm.get_MDP_matrix()
solver = mdptoolbox.mdp.PolicyIteration(P, R, 0.99)
solver.run()
print(solver.V)
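
To inspect the strategy itself rather than its values, the per-state actions live in solver.policy. A small sketch under the assumption (consistent with OSM_act in Example #1) that osm._state_dict is keyed by state names and osm._name_to_index maps a name to its row in the MDP matrices:

# Hypothetical inspection loop; relies on OSM's private helpers shown above.
for state in osm._state_dict:
    print(state, '->', solver.policy[osm._name_to_index(state)])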