class ChildrenValuePrinter(HumanPrintWrapper):
    """Print wrapper that reports the value of the current state and its children.

    For each action, the wrapped env's full state is copied into a private
    render environment, the action is applied there, and the value function
    is evaluated on the resulting state.
    """

    def __init__(self, env, value_fun):
        """
        Args:
          env: environment to wrap; ``env.init_kwargs`` is used to build a
            private SokobanEnv for simulating child states.
          value_fun: callable: obs, states -> value; it is invoked with the
            keyword argument `states`.
        """
        super().__init__(env)
        self.render_env = SokobanEnv(**env.init_kwargs)
        self.value_fun = value_fun

    def formatted_state_value(self, state):
        # The output is indexed [0][0]; presumably a (batch, 1) array for a
        # single state -- TODO confirm.
        raw_value = self.value_fun(states=state)[0][0]
        return "{:.2f}".format(raw_value)

    def build_texts(self, obs, reward, done, info):
        root_state = self.env.clone_full_state()
        root_text = self.formatted_state_value(root_state)
        simulator = self.render_env

        def child_text(action):
            # Re-seed the simulator from the root before every action.
            simulator.restore_full_state(root_state)
            simulator.step(action)
            return self.formatted_state_value(simulator.clone_full_state())

        child_texts = [child_text(a) for a in range(simulator.action_space.n)]
        joined = " ".join(child_texts)
        print('Children values: {}'.format(joined))
        return [
            'Value: {}'.format(root_text),
            'Children values: {}'.format(joined),
        ]
def _load_shard_vf(shard, data_files_prefix, env_kwargs, filter_values_fn=None,
                   transform_values_fn=None):
    """Load one shard of value-function data as (observation, value) pairs.

    Args:
      shard: shard identifier forwarded to ``_load_shard``.
      data_files_prefix: path prefix forwarded to ``_load_shard``.
      env_kwargs: kwargs for the SokobanEnv used to render stored states.
      filter_values_fn: optional predicate on a value; states for which it
        returns a truthy result are skipped.
      transform_values_fn: optional function applied to each kept value.

    Returns:
      Tuple ``(data_x, data_y, {})``: ``data_x`` is the array of rendered
      observations and ``data_y`` the array of values. A 1-D ``data_y`` is
      reshaped to ``(num_samples, 1)``.
    """
    shard_roots = _load_shard(shard, data_files_prefix)
    render_env = SokobanEnv(**env_kwargs)
    data_x = []
    data_y = []
    vf = ValueLoader()
    for vf_for_root in shard_roots:
        root = vf.load_vf_for_root(vf_for_root, compressed=True)
        # Iterate the dumped (env_state, value) pairs under their own name
        # rather than rebinding the shard collection being looped over.
        for env_state, v in vf.dump_vf_for_root(root):
            if filter_values_fn and filter_values_fn(v):
                continue
            if transform_values_fn:
                v = transform_values_fn(v)
            render_env.restore_full_state(env_state)
            data_x.append(render_env.render(mode=render_env.mode))
            data_y.append(v)
    data_y = np.asarray(data_y)
    if data_y.ndim == 1:
        # Scalar targets must still be batched as (num_samples, 1).
        data_y = data_y.reshape((len(data_y), 1))
    return np.asarray(data_x), data_y, {}
class PolicyFromFullTree(Policy):
    """Policy that exhaustively expands every action sequence up to ``depth``.

    Each node reached is scored as its value-function output plus the reward
    accumulated along the shallowest path found to it. A further 1000 is added
    to the value if the episode ended on reaching the node. The first action
    of the path to the best-scoring node is returned.
    """

    def __init__(self, value_fn, env_kwargs, depth=4):
        self.render_env = SokobanEnv(**env_kwargs)
        self.env_n_actions = self.render_env.action_space.n
        self.value_function = value_fn
        self.env = SokobanEnv(**env_kwargs)
        self.env.reset()
        self.depth = depth
        # Search tree of the most recent best_actions call:
        # node -> (value, branch_reward, depth, root_action, actions).
        self.nodes = dict()
        # Raw value_function outputs keyed by node. They depend only on the
        # node itself, so unlike self.nodes they are safe to reuse across calls.
        self._value_cache = dict()

    def _node_value(self, node):
        """Return value_function for ``node``, memoized across calls."""
        if node not in self._value_cache:
            self._value_cache[node] = self.value_function(states=np.array(node))
        return self._value_cache[node]

    def best_actions(self, state):
        # Depth, reward and root action are all relative to ``state``, so a
        # tree built from a previous root must not leak into this search.
        self.nodes = dict()
        all_sequences = product(range(self.env.action_space.n),
                                repeat=self.depth)
        for actions in all_sequences:
            root_action = actions[0]
            self.env.restore_full_state(state)
            branch_reward = 0
            for current_depth, action in enumerate(actions, start=1):
                _, reward, done, _ = self.env.step(action)
                branch_reward += reward
                node = tuple(self.env.clone_full_state())
                if node not in self.nodes:
                    value = self._node_value(node)
                    if done:
                        # Non-mutating add: the cached value may be an array.
                        value = value + 1000
                    self.nodes[node] = (value, branch_reward, current_depth,
                                        root_action, actions[:current_depth])
                else:
                    value, _, previous_depth, _, _ = self.nodes[node]
                    if previous_depth > current_depth:
                        # Prefer the shallowest path to an already-seen node.
                        self.nodes[node] = (value, branch_reward, current_depth,
                                            root_action,
                                            actions[:current_depth])
                if done:
                    break
        best_node = max(
            self.nodes,
            key=lambda node: self.nodes[node][0] + self.nodes[node][1])
        _, _, _, best_root_action, _ = self.nodes[best_node]
        return [best_root_action]
class ValueFromKerasNet(Value, ABC):
    """Base for values computed by a Keras model on rendered observations.

    Subclasses implement ``__call__``. ``_network_prediction`` renders a
    state and runs the model on it as a batch of one.
    """

    def __init__(self, model, env_kwargs):
        """
        Args:
          model: a Keras model, or a path string handed to ``load_model``.
          env_kwargs: kwargs for the SokobanEnv used to render states.
        """
        self.model = load_model(model) if isinstance(model, str) else model
        self.env = SokobanEnv(**env_kwargs)
        self.env.reset()

    def _network_prediction(self, state):
        """Render ``state`` and return the model's raw prediction for it."""
        self.env.restore_full_state(state)
        batch = np.expand_dims(self.env.render(), axis=0)
        return self.model.predict(batch)

    def __call__(self, state):
        raise NotImplementedError
class PolicyFromNet(Policy):
    """Greedy policy returning the argmax of a Keras model's output for a state."""

    def __init__(self, model, env_kwargs):
        """
        Args:
          model: a Keras model, or a path string handed to ``load_model``.
          env_kwargs: kwargs for the SokobanEnv instances used here.
        """
        self.render_env = SokobanEnv(**env_kwargs)
        self.env_n_actions = self.render_env.action_space.n
        self.model = model if not isinstance(model, str) else load_model(model)
        self.env = SokobanEnv(**env_kwargs)
        self.env.reset()
        # Only single-output models are supported.
        assert len(self.model.outputs) == 1

    def best_actions(self, state):
        self.env.restore_full_state(state)
        batch = np.expand_dims(self.env.render(), axis=0)
        action_scores = self.model.predict(batch)[0]
        return [np.argmax(action_scores)]
class QFromV(object):
    """Derive Q-values from a state-value function by one-step lookahead.

    For each action a, Q(s, a) = reward + value_function(s')[0] when the step
    does not end the episode, and just the reward otherwise.
    """

    def __init__(self, value_function, env_kwargs, nan_for_zero_value=True,
                 copy_negative=True):
        """
        Args:
          value_function: callable invoked as ``value_function(states=...)``;
            its result is indexed with ``[0]`` to obtain a value.
          env_kwargs: kwargs for the SokobanEnv used to simulate actions.
          nan_for_zero_value: if True, return NaN for every action when the
            state's value equals 0.
          copy_negative: if True and the state's value is negative, return
            that value for every action without simulating children.
        """
        self.value_function = value_function
        self.env = SokobanEnv(**env_kwargs)
        self.env.reset()
        self.nan_for_zero_value = nan_for_zero_value
        self.copy_negative_values = copy_negative

    @property
    def env_n_actions(self):
        return self.env.action_space.n

    def q_values(self, state):
        """Return a list with one float Q-value (or NaN) per action."""
        if self.nan_for_zero_value or self.copy_negative_values:
            # Evaluate the root once and share it between both shortcuts.
            state_value = self.value_function(states=state)
            if self.nan_for_zero_value and state_value == 0:
                # Value might not have children for Sokoban success states.
                return [np.nan] * self.env_n_actions
            if self.copy_negative_values:
                # For speed-up: skip the lookahead for negative states.
                val = state_value[0]
                if val < 0:
                    # float() keeps the return type consistent with the
                    # lookahead path below.
                    return [float(val)] * self.env_n_actions
        q_values = []
        for action in range(self.env_n_actions):
            self.env.restore_full_state(state)
            _, reward, done, _ = self.env.step(action)
            value = reward
            child_state = self.env.clone_full_state()
            if not done:
                value += self.value_function(states=child_state)[0]
            q_values.append(float(value))
        return q_values
def _load_shard_best_action_ignore_finall(shard, data_files_prefix, env_kwargs):
    """Load one shard as (observation, one-hot best action) training pairs.

    For every state dumped from the shard's value functions, the target is a
    one-hot vector of length ``env.action_space.n``. It marks the smallest
    action index among those returned by
    ``PolicyFromValue.act(..., return_single_action=False)``, so ties are
    broken towards the lowest index. States whose value is 0 or -inf are
    skipped; see the TODO in the loop.

    Returns:
      Tuple ``(data_x, data_y, {"value": values})`` of numpy arrays: rendered
      observations, one-hot action targets and the matching state values.
    """
    boards = _load_shard(shard, data_files_prefix)
    render_env = SokobanEnv(**env_kwargs)
    n_actions = render_env.action_space.n
    data_x = []
    data_y = []
    data_value = []
    vf = ValueLoader()
    policy = PolicyFromValue(vf, env_kwargs)
    assert policy.env_n_actions == n_actions
    for vf_for_root in boards:
        root = vf.load_vf_for_root(vf_for_root, compressed=True)
        for node_state, v in vf.dump_vf_for_root(root):
            if v in [0, -float("inf")]:
                # TODO(kc): ValuePerfect does not produce some states which can be
                # obtained after solving game. How to clean it up?
                continue
            render_env.restore_full_state(node_state)
            data_x.append(render_env.render(mode=render_env.mode))
            best_actions = policy.act(node_state, return_single_action=False)
            # Builtin int: the np.int alias was removed in NumPy 1.24.
            one_hot_y = np.zeros(shape=n_actions, dtype=int)
            one_hot_y[np.min(best_actions)] = 1
            data_y.append(one_hot_y)
            data_value.append(v)
    return np.asarray(data_x), np.asarray(data_y), \
        dict(value=np.asarray(data_value))
def process_board_data(compressed_data, target, env_kwargs, sample_data,
                       max_sample_size, random_state):
    """Turn one board's compressed perfect-value data into training arrays.

    Args:
      compressed_data: dictionary with keys containing ["full_env_state",
        "perfect_value", "perfect_q"], mapping to compressed arrays.
      target: a ``Target`` member selecting what ``data_y`` contains.
      env_kwargs: kwargs for the SokobanEnv used to render states.
      sample_data: must be truthy; processing without sampling is not
        implemented.
      max_sample_size: forwarded to the sampling helper.
      random_state: forwarded to the sampling helper and, for
        ``Target.DELTA_VALUE``, to ``extract_delta_value``.

    Returns:
      Tuple ``(data_x, data_y, {})``.

    Raises:
      ValueError: if ``target`` is not a supported ``Target``.
      NotImplementedError: if ``sample_data`` is falsy.
    """
    render_env = SokobanEnv(**env_kwargs)
    keys = compressed_data.keys()
    assert_v2_keys(compressed_data)
    data = {key: decompress_np_array(compressed_data[key]) for key in keys}
    assert_env_and_state_match(env_kwargs, data["full_env_state"][0])

    # Action targets (and VF_SOLVABLE_ONLY) keep only solvable states and use
    # simple sampling; every other target keeps all states and stratifies.
    solvable_only_targets = (Target.VF_SOLVABLE_ONLY, Target.BEST_ACTION,
                             Target.BEST_ACTION_FRAMESTACK)
    stratified_targets = (Target.VF, Target.STATE_TYPE, Target.VF_AND_TYPE,
                          Target.NEXT_FRAME, Target.DELTA_VALUE,
                          Target.VF_DISCOUNTED, Target.NEXT_FRAME_AND_DONE)
    if target in solvable_only_targets:
        drop_unsolvable = True
        sample_fn = simple_sample
    elif target in stratified_targets:
        drop_unsolvable = False
        sample_fn = stratified_sample
    else:
        raise ValueError("Unknown target {}".format(target))

    def filter_values_fn(v, q):
        # True means "drop this state".
        return drop_unsolvable and not is_solvable_state(v, q)

    # Builtin bool: the np.bool alias was removed in NumPy 1.24.
    mask = ~np.array([
        filter_values_fn(v, q)
        for v, q in zip(data['perfect_value'], data['perfect_q'])
    ], dtype=bool)
    data = {key: data[key][mask] for key in keys}

    if not sample_data:
        # NotImplemented is a constant, not an exception; raising the
        # intended NotImplementedError.
        raise NotImplementedError(
            "process_board_data only supports sample_data=True")
    sample_ix = sample_fn(data["perfect_value"], data["perfect_q"],
                          max_sample_size, random_state)

    if target == Target.DELTA_VALUE:
        data_x, data_y = extract_delta_value(data, sample_ix, render_env,
                                             random_state)
    elif target == Target.VF_DISCOUNTED:
        data_x, data_y = extract_discounted_value(
            sample_ix,
            states=data["full_env_state"],
            perfect_v=data["perfect_value"],
            perfect_q=data["perfect_q"],
            render_env=render_env,
        )
    elif target == Target.BEST_ACTION_FRAMESTACK:
        data_x, data_y = extract_best_action_from_framestack(
            sample_ix,
            states=data["full_env_state"],
            perfect_v=data["perfect_value"],
            perfect_q=data["perfect_q"],
            render_env=render_env,
        )
    else:
        data = {key: data[key][sample_ix] for key in keys}
        if target == Target.NEXT_FRAME:
            data_x, data_y = extract_next_frame_input_and_target(
                data["full_env_state"], render_env)
        else:
            obs = []
            for node_state in data['full_env_state']:
                render_env.restore_full_state(node_state)
                obs.append(render_env.render(mode=render_env.mode))
            data_x = np.array(obs)
            data_y = extract_target_from_value(perfect_v=data["perfect_value"],
                                               perfect_q=data["perfect_q"],
                                               target=target)
    if isinstance(data_y, np.ndarray):
        assert len(data_y.shape) > 1, "data_y should be batched (if target is " \
            "scalar it should have shape (num_samples, 1))"
    return data_x, data_y, {}