Example #1
import numpy as np
from gym.spaces import Box, Discrete
from ray.rllib.agents.marwil import MARWILTrainer  # Ray 1.x import path (assumed)
from ray.rllib.examples.env.random_env import RandomEnv


def restore_trainer(checkpoint_path):
    # Rebuild the trainer with the same spaces used at training time,
    # then load the weights from the given checkpoint.
    env_config = {
        "action_space": Discrete(12),
        "observation_space": Box(-1000.0, 1000.0, (47, ), np.float32),
    }
    config = {"explore": False, "env_config": env_config}
    _trainer = MARWILTrainer(config=config, env=RandomEnv)
    _trainer.restore(checkpoint_path)
    return _trainer
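
A minimal usage sketch for the helper above (the checkpoint path is hypothetical; `obs` is a dummy observation matching the declared space):

import ray

ray.init()
trainer = restore_trainer('/tmp/marwil/checkpoint_100/checkpoint-100')  # hypothetical path
obs = np.zeros(47, dtype=np.float32)
action = trainer.compute_action(obs)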
Example #2
import logging
import os
import pickle

from tqdm import trange
from ray.rllib.agents.marwil import MARWILTrainer  # Ray 1.x import path (assumed)

LOGGER = logging.getLogger(__name__)  # assumed module-level logger


class MARWILrl(object):
    def __init__(self, env, env_config, config):
        self.config = config
        self.config['env_config'] = env_config
        self.env = env(env_config)
        self.agent = MARWILTrainer(config=self.config, env=env)

    def fit(self, checkpoint=None):
        if checkpoint is None:
            checkpoint = os.path.join(os.getcwd(), 'data/checkpoint_rl.pkl')
        for idx in trange(5):
            result = self.agent.train()
            LOGGER.warning('result: %s', result)
            if (idx + 1) % 5 == 0:
                LOGGER.warning('Save checkpoint at: {}'.format(idx + 1))
                state = self.agent.save_to_object()
                with open(checkpoint, 'wb') as fp:
                    pickle.dump(state, fp, protocol=pickle.HIGHEST_PROTOCOL)
        return result

    def predict(self, checkpoint=None):
        if checkpoint is not None:
            with open(checkpoint, 'rb') as fp:
                state = pickle.load(fp)
            self.agent.restore_from_object(state)
        done = False
        episode_reward = 0
        obs = self.env.reset()
        actions = []
        while not done:
            action = self.agent.compute_action(obs)
            actions.append(action)
            obs, reward, done, info = self.env.step(action)
            episode_reward += reward
        results = {'action': actions, 'reward': episode_reward}
        return results
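
A short usage sketch for the wrapper class above, assuming a Gym-style environment class `MyEnv`, its `env_config` dict, and a MARWIL `config` dict are defined elsewhere (all hypothetical names):

import ray

ray.init()
rl = MARWILrl(MyEnv, env_config, config)
rl.fit()                  # train a few iterations and pickle the trainer state
results = rl.predict()    # roll out one episode with the trained policy
print(results['reward'])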
Example #3
def __init__(self, env, env_config, config):
    # Constructor of the MARWIL wrapper: store the config, attach env_config,
    # build the environment and the MARWILTrainer.
    self.config = config
    self.config['env_config'] = env_config
    self.env = env(env_config)
    self.agent = MARWILTrainer(config=self.config, env=env)
Example #4
# Assumes PPO_agent, env_config, marwil_checkpoint, pgr_checkpoint and the
# SSA_Tasker_Env class are defined earlier in the original script.
from ray.rllib.agents.marwil import MARWILTrainer, DEFAULT_CONFIG as MARWIL_CONFIG  # Ray 1.x paths (assumed)
from ray.rllib.agents.pg import PGTrainer, DEFAULT_CONFIG as PG_CONFIG

PPO_agent.get_policy().config['explore'] = False

logdir = '/home/ash/ray_results/ssa_experiences/agent_visible_greedy_spoiled/' + str(
    env_config['rso_count']) + 'RSOs_jones_flatten_10000episodes/'

marwil_config = MARWIL_CONFIG.copy()
marwil_config['evaluation_num_workers'] = 1
marwil_config['env_config'] = env_config
marwil_config['evaluation_interval'] = 1
marwil_config['evaluation_config'] = {'input': 'sampler'}  # evaluate on live rollouts, not the offline data
marwil_config['beta'] = 1  # 0 = behaviour cloning, 1 = advantage-weighted MARWIL
marwil_config['input'] = logdir  # read offline experiences from this directory
marwil_config['explore'] = False

MARWIL_agent = MARWILTrainer(config=marwil_config, env=SSA_Tasker_Env)
MARWIL_agent.restore(marwil_checkpoint)
MARWIL_agent.get_policy().config['explore'] = False

pg_config = PG_CONFIG.copy()
pg_config['batch_mode'] = 'complete_episodes'
pg_config['train_batch_size'] = 2000
pg_config['lr'] = 0.0001
pg_config['evaluation_interval'] = None
pg_config['postprocess_inputs'] = True
pg_config['env_config'] = env_config
pg_config['explore'] = False

PGR_agent = PGTrainer(config=pg_config, env=SSA_Tasker_Env)
PGR_agent.restore(pgr_checkpoint)
PGR_agent.get_policy().config['explore'] = False
Example #5
# (excerpt starts mid-script: the assignments creating `logdir`, `marwil_config`,
#  and `pg_config` from the default configs are cut off)
    env_config['rso_count']) + 'RSOs_jones_flatten_10000episodes/'
marwil_config['evaluation_num_workers'] = 1
marwil_config['env_config'] = env_config
marwil_config['evaluation_interval'] = 1
marwil_config['evaluation_config'] = {'input': 'sampler'}
marwil_config['beta'] = 1  # 0
marwil_config['input'] = logdir
# if env_config['rso_count'] == 40:
#     marwil_config['model']['fcnet_hiddens'] = [512, 512]

# !--- experiment
pg_config['lr'] = 1e-7
pg_config['train_batch_size'] = 4000
pg_config['num_workers'] = 8

MARWIL_trainer = MARWILTrainer(config=marwil_config, env=SSA_Tasker_Env)
PG_trainer = PGTrainer(config=pg_config, env=SSA_Tasker_Env)

if env_config['rso_count'] == 10:
    MARWIL_trainer.restore(
        '/home/ash/ray_results/MARWIL_SSA_Tasker_Env_2020-09-10_08-20-065a0phk5m/checkpoint_5000/checkpoint-5000'
    )  # 10 SSA Complete
elif env_config['rso_count'] == 20:
    MARWIL_trainer.restore(
        '/home/ash/ray_results/MARWIL_SSA_Tasker_Env_2020-09-10_08-24-14dv96mkld/checkpoint_5000/checkpoint-5000'
    )  # 20 SSA Complete
elif env_config['rso_count'] == 40:
    MARWIL_trainer.restore(
        '/home/ash/ray_results/MARWIL_SSA_Tasker_Env_2020-09-10_08-46-52wmigl7hj/checkpoint_10000/checkpoint-10000'
    )  # 40 SSA Complete
else:
    pass  # (remaining branch is cut off in the original excerpt)
Example #6
    "input_evaluation": ["wis"],
    "evaluation_config": {"input": "sampler"},
    "beta": 1, #tune.grid_search([0, 1])
}

a3c_config = {
    "num_workers": 1,
    "gamma": 0.95,
}

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    register_env(
        "ExternalWorld",
        #lambda _: HeartsEnv()
        lambda _: ExternalWorld(NewsWorld(dict()), episodes=1000)
    )
    marwil = MARWILTrainer(config=marwil_config, env="ExternalWorld")  # use the env registered above

    i = 1
    while i < args.stop:
        result = marwil.train()
        print("Iteration {}, Episodes {}, Mean Reward {}, Mean Length {}".format(
            i, result['episodes_this_iter'], result['episode_reward_mean'], result['episode_len_mean']
        ))
        i += 1

    ray.shutdown()
Example #7
import datetime

from ray.rllib.agents.marwil import MARWILTrainer, DEFAULT_CONFIG  # Ray 1.x import paths (assumed)
from ray.tune.logger import pretty_print

# `agent`, `env_config`, and SSA_Tasker_Env are defined earlier in the original script.
logdir = '/home/ash/ray_results/ssa_experiences/' + agent + '/' + str(
    env_config['rso_count']) + 'RSOs_jones_flatten_10000episodes/'

config = DEFAULT_CONFIG.copy()
config['evaluation_num_workers'] = 10
config['env_config'] = env_config
config['evaluation_interval'] = 10
config['train_batch_size'] = 10000
# if env_config['rso_count'] == 40:
#     config['model']['fcnet_hiddens'] = [512, 512]
config['evaluation_config'] = {'input': 'sampler'}
config['beta'] = 1  # 0 = behaviour cloning, 1 = advantage-weighted MARWIL
config['input'] = logdir

trainer_MARWIL = MARWILTrainer(config=config, env=SSA_Tasker_Env)
best_athlete = 480
episode_len_mean = []
episode_reward_mean = []
episode_reward_max = []
episode_reward_min = []
start = datetime.datetime.now()
num_steps_trained = []
clock_time = []
training_iteration = []
for i in range(5000):
    # Perform one iteration of MARWIL training on the offline data
    result_MARWIL = trainer_MARWIL.train()
    print(pretty_print(result_MARWIL))

    if result_MARWIL['training_iteration'] % config['evaluation_interval'] == 0: