def planned(cell, objective_fn, embedded, prev_action, planner, context=1,
            length=20, amount=1000, debug=False):
    """Unroll `cell` on observed context frames and append a planned rollout.

    The first `context` steps of `embedded` and `prev_action` are fed to
    `cell` with observations enabled. The final RNN state is handed to
    `planner`, and the states it returns are concatenated after the
    closed-loop states along axis 1.

    Args:
      cell: RNN cell passed to `dynamic_rnn` and to `planner`.
      objective_fn: Objective forwarded unchanged to `planner`.
      embedded: Embedded observations; sliced as `[:, :context]`.
      prev_action: Previous actions; sliced as `[:, :context]`.
      planner: Callable returning a 3-tuple whose last two items are the
        planned states and the return.
      context: Number of leading steps that use observations.
      length: Planning horizon, forwarded as `horizon`.
      amount: Forwarded to `planner` as `amount`.
      debug: When true, assert that the joined state has `context + length`
        steps along axis 1.

    Returns:
      Tuple `(state, return_)`.
    """
    use_obs = tf.ones(tf.shape(input=embedded[:, :context, :1])[:3], tf.bool)
    rnn_inputs = (embedded[:, :context], prev_action[:, :context], use_obs)
    rnn_outputs, last_state = tf.compat.v1.nn.dynamic_rnn(
        cell, rnn_inputs, dtype=tf.float32)
    _, closed_state = rnn_outputs
    _, plan_state, return_ = planner(
        cell, objective_fn, last_state,
        obs_shape=shape.shape(embedded)[2:],
        action_shape=shape.shape(prev_action)[2:],
        horizon=length,
        amount=amount)

    def join_steps(closed, planned_part):
        return tf.concat([closed, planned_part], 1)

    state = nested.map(join_steps, closed_state, plan_state)
    if not debug:
        return state, return_
    steps_match = tf.compat.v1.assert_equal(
        tf.shape(input=nested.flatten(state)[0])[1], context + length)
    with tf.control_dependencies([steps_match]):
        state = nested.map(tf.identity, state)
        return_ = tf.identity(return_)
    return state, return_
def perform(self, agent_indices, observ):
    """Compute the next action for the selected agents and store the state.

    Args:
      agent_indices: Indices of the agents to act for; used to gather from
        and scatter into the stored state variables.
      observ: Batch of raw observations, one per selected agent.

    Returns:
      Tuple of the action tensor and an empty string constant.
    """
    config = self._config

    def select_agents(tensor):
        return tf.gather(tensor, agent_indices)

    def store_agents(var, val):
        return tf.compat.v1.scatter_update(var, agent_indices, val)

    observ = config.preprocess_fn(observ)
    # Add a length-one axis for the encoder call and drop it again.
    embedded = config.encoder({'image': observ[:, None]})[:, 0]
    state = nested.map(select_agents, self._state)
    # Read the previous action before it is overwritten further below.
    prev_action = self._prev_action + 0
    with tf.control_dependencies([prev_action]):
        use_obs = tf.ones(tf.shape(input=agent_indices), tf.bool)[:, None]
        _, state = self._cell((embedded, prev_action, use_obs), state)
    plan = config.planner(
        self._cell, config.objective, state,
        embedded.shape[1:].as_list(),
        prev_action.shape[1:].as_list())
    action = plan[:, 0]
    exploration = config.exploration
    if exploration:
        scale = exploration.scale
        if exploration.schedule:
            scale *= exploration.schedule(self._step)
        action = tfd.Normal(action, scale).sample()
    action = tf.clip_by_value(action, -1, 1)
    remember_action = self._prev_action.assign(action)
    remember_state = nested.map(
        store_agents, self._state, state, flatten=True)
    with tf.control_dependencies(remember_state + (remember_action, )):
        return tf.identity(action), tf.constant('')
def test_namedtuple(self):
    """Check that `nested.map` rebuilds namedtuples found inside lists."""
    Foo = collections.namedtuple('Foo', 'value')
    foo = [Foo(42)]
    bar = [Foo(13)]
    swapped = nested.map(lambda x, y: (y, x), foo, bar)
    self.assertEqual([Foo((13, 42))], swapped)
    summed = nested.map(lambda x, y: x + y, foo, bar)
    self.assertEqual([Foo(55)], summed)
def chunk_sequence(sequence, chunk_length, randomize=True, num_chunks=None):
    """Split a nested dict of sequence tensors into a batch of chunks.

    This function expects a single sequence, not a batch of sequences. If
    `sequence` has a `length` key it is popped from the passed dict (the
    caller's dict is modified); otherwise the length is read from the first
    axis of the first flattened tensor. The result always gets a `length`
    key holding `chunk_length` for every chunk.

    Frames before the chosen offset and frames after the last full chunk are
    discarded.

    Args:
      sequence: Nested dict of tensors with time dimension.
      chunk_length: Size of chunks the sequence will be split into.
      randomize: Start chunking from a random offset drawn from
        `[0, length - num_chunks * chunk_length]`. Without `num_chunks`, one
        chunk fewer than would fit is used (but at least one).
      num_chunks: Optionally specify the exact number of chunks to be
        extracted from the sequence. Requires input to be long enough.

    Returns:
      Nested dict of sequence tensors with chunk dimension.
    """
    with tf.device('/cpu:0'):
        if 'length' in sequence:
            length = sequence.pop('length')
        else:
            first_tensor = nested.flatten(sequence)[0]
            length = tf.shape(first_tensor)[0]
        if num_chunks is not None:
            # Adding `0 * length` turns a Python number into a tensor.
            num_chunks = num_chunks + 0 * length
        elif randomize:
            num_chunks = tf.maximum(1, length // chunk_length - 1)
        else:
            num_chunks = length // chunk_length
        used_length = num_chunks * chunk_length
        if randomize:
            max_offset = length - used_length
            offset = tf.random_uniform(
                (), 0, max_offset + 1, dtype=tf.int32)
        else:
            offset = 0

        def clip(tensor):
            return tensor[offset:offset + used_length]

        def to_chunks(tensor):
            trailing = tensor.shape[1:].as_list()
            return tf.reshape(tensor, [num_chunks, chunk_length] + trailing)

        chunks = nested.map(to_chunks, nested.map(clip, sequence))
        chunks['length'] = chunk_length * tf.ones(
            (num_chunks, ), dtype=tf.int32)
        return chunks
def begin_episode(self, agent_indices):
    """Zero the stored state of the given agents and the previous action.

    Args:
      agent_indices: Indices of the agents whose state rows are reset.

    Returns:
      Empty string constant that depends on all reset operations.
    """
    def select_agents(tensor):
        return tf.gather(tensor, agent_indices)

    def zero_agents(var, val):
        return tf.compat.v1.scatter_update(var, agent_indices, 0 * val)

    current = nested.map(select_agents, self._state)
    reset_state = nested.map(
        zero_agents, self._state, current, flatten=True)
    # The previous action is cleared for all agents, not only the selected.
    zero_action = tf.zeros_like(self._prev_action)
    reset_prev_action = self._prev_action.assign(zero_action)
    with tf.control_dependencies(reset_state + (reset_prev_action, )):
        return tf.constant('')
def perform(self, agent_indices, observ):
    """Plan the next action, optionally explore epsilon-greedily, and store.

    With exploration enabled, a single scalar uniform draw is compared with
    the exploration scale; when it is below the scale the whole batch of
    planned actions is replaced by one-hot actions whose index is sampled
    uniformly and independently for every batch row.

    The earlier implementation shuffled a batch of identical `one_hot(0)`
    rows with `tf.random.shuffle`, which only permutes the first axis, so the
    "random" action was always action 0.

    Args:
      agent_indices: Indices of the agents to act for.
      observ: Batch of raw observations, one per selected agent.

    Returns:
      Tuple of the action tensor and an empty string constant.
    """
    observ = self._config.preprocess_fn(observ)
    embedded = self._config.encoder({'image': observ[:, None]})[:, 0]
    state = nested.map(lambda tensor: tf.gather(tensor, agent_indices),
                       self._state)
    # Read the previous action before it is overwritten further below.
    prev_action = self._prev_action + 0
    with tf.control_dependencies([prev_action]):
        use_obs = tf.ones(tf.shape(agent_indices), tf.bool)[:, None]
        _, state = self._cell((embedded, prev_action, use_obs), state)
    action = self._config.planner(self._cell, self._config.objective, state,
                                  embedded.shape[1:].as_list(),
                                  prev_action.shape[1:].as_list())
    # First action is best action.
    action = action[:, 0]
    # Random exploration (if exploring).
    if self._config.exploration:
        scale = self._config.exploration.scale
        if self._config.exploration.schedule:
            scale *= self._config.exploration.schedule(self._step)
        planned_action = action
        action_shape = planned_action.shape
        batch_size, num_actions = action_shape.as_list()

        def random_one_hot():
            # One independent uniformly sampled action index per batch row.
            indices = tf.random.uniform(
                [batch_size], minval=0, maxval=num_actions, dtype=tf.int32)
            return tf.one_hot(
                indices, depth=num_actions, dtype=planned_action.dtype)

        explore = tf.random.uniform(shape=(), minval=0.0, maxval=1.0) < scale
        action = tf.reshape(
            tf.cond(explore, random_one_hot, lambda: planned_action),
            action_shape)
    action = tf.clip_by_value(action, -1, 1)
    remember_action = self._prev_action.assign(action)
    remember_state = nested.map(
        lambda var, val: tf.scatter_update(var, agent_indices, val),
        self._state, state, flatten=True)
    with tf.control_dependencies(remember_state + (remember_action, )):
        return tf.identity(action), tf.constant('')
def chunk_sequence(sequence, chunk_length, randomize=True, num_chunks=None):
    """Cut one nested sequence into equally sized chunks.

    The input is a single sequence rather than a batch. A `length` entry, if
    present, is popped from the passed dict (so the caller's dict changes);
    otherwise the length is the size of the first axis of the first
    flattened tensor. The returned dict always carries a `length` entry
    equal to `chunk_length` for every chunk.

    Frames in front of the chosen offset and frames behind the last complete
    chunk are dropped.

    Args:
      sequence: Nested dict of tensors with time dimension.
      chunk_length: Size of chunks the sequence will be split into.
      randomize: Draw the start offset at random from
        `[0, length - num_chunks * chunk_length]`. When `num_chunks` is not
        given, one chunk fewer than would fit is extracted, but at least one.
      num_chunks: Optional exact number of chunks to extract. Requires the
        input to be long enough.

    Returns:
      Nested dict of sequence tensors with chunk dimension.
    """
    with tf.device('/cpu:0'):
        if 'length' in sequence:
            length = sequence.pop('length')
        else:
            length = tf.shape(nested.flatten(sequence)[0])[0]
        if num_chunks is None:
            fitting = length // chunk_length
            num_chunks = tf.maximum(1, fitting - 1) if randomize else fitting
        else:
            # `0 * length` promotes a Python number to a tensor.
            num_chunks = num_chunks + 0 * length
        used_length = num_chunks * chunk_length
        offset = 0
        if randomize:
            slack = length - used_length
            offset = tf.random_uniform((), 0, slack + 1, dtype=tf.int32)
        stop = offset + used_length
        clipped = nested.map(lambda tensor: tensor[offset:stop], sequence)

        def split(tensor):
            target = [num_chunks, chunk_length] + tensor.shape[1:].as_list()
            return tf.reshape(tensor, target)

        chunks = nested.map(split, clipped)
        chunks['length'] = chunk_length * tf.ones(
            (num_chunks, ), dtype=tf.int32)
        return chunks
def closed_loop(cell, embedded, prev_action, debug=False):
    """Unroll `cell` over the full sequence with observations enabled.

    Args:
      cell: RNN cell passed to `dynamic_rnn`; its output is unpacked as a
        `(prior, posterior)` pair.
      embedded: Embedded observations.
      prev_action: Previous actions aligned with `embedded`.
      debug: When true, assert that the posterior has as many steps along
        axis 1 as `embedded`.

    Returns:
      Tuple `(prior, posterior)`.
    """
    use_obs = tf.ones(tf.shape(input=embedded[:, :, :1])[:3], tf.bool)
    outputs, _ = tf.compat.v1.nn.dynamic_rnn(
        cell, (embedded, prev_action, use_obs), dtype=tf.float32)
    prior, posterior = outputs
    if not debug:
        return prior, posterior
    steps_match = tf.compat.v1.assert_equal(
        tf.shape(input=nested.flatten(posterior)[0])[1],
        tf.shape(input=embedded)[1])
    with tf.control_dependencies([steps_match]):
        prior = nested.map(tf.identity, prior)
        posterior = nested.map(tf.identity, posterior)
    return prior, posterior
def begin_episode(self, agent_indices):
    """Reset the agent state, the previous action and the rollout saver.

    Calling this method sets the Python-side counter `self._length` to zero.
    The returned op additionally depends on a print op and on a `py_func`
    wrapping `self.saver.reset`.

    Args:
      agent_indices: Indices of the agents whose state rows are reset.

    Returns:
      Empty string constant that depends on all reset operations.
    """
    self._length = 0
    announce = tf.print('reset everything')
    reset_saver = tf.py_func(self.saver.reset, inp=[], Tout=tf.bool)

    def select_agents(tensor):
        return tf.gather(tensor, agent_indices)

    def zero_agents(var, val):
        return tf.scatter_update(var, agent_indices, 0 * val)

    current = nested.map(select_agents, self._state)
    reset_state = nested.map(
        zero_agents, self._state, current, flatten=True)
    reset_prev_action = self._prev_action.assign(
        tf.zeros_like(self._prev_action))
    dependencies = reset_state + (reset_prev_action, reset_saver, announce)
    with tf.control_dependencies(dependencies):
        return tf.constant('')
def observation_space(self):
    """Return the wrapped environment's observation space after processing.

    Every box in `self._env.observation_space` is replaced by a new
    `gym.spaces.Box` whose bounds are the processed `low` and `high` of the
    original box. The dtype is taken from the processed lower bound, which
    is computed once and reused instead of calling `_process_fn` on
    `box.low` a second time.

    Returns:
      Structure matching `self._env.observation_space` with converted boxes.
    """
    def convert(box):
        low = self._process_fn(box.low)
        high = self._process_fn(box.high)
        return gym.spaces.Box(low, high, dtype=low.dtype)

    return nested.map(convert, self._env.observation_space)
def __init__(self, batch_env, step, is_training, should_log, config):
    """Create per-model state variables and the previous-action variable.

    Args:
      batch_env: Batched environment; `len(batch_env)` gives the batch size
        and `batch_env.action` provides the action shape.
      step: Trainer step, not environment step.
      is_training: Stored on the instance.
      should_log: Stored on the instance.
      config: Configuration providing `num_models` and `cell`, where `cell`
        is indexed by model number.
    """
    self._batch_env = batch_env
    self._step = step  # Trainer step, not environment step.
    self._is_training = is_training
    self._should_log = should_log
    self._config = config
    self._num_models = config.num_models
    self._cell = config.cell

    def var_like(x):
        # Local variable named after the tensor and initialized to zeros.
        name = x.name.split(':')[0].replace('/', '_') + '_var'
        return tf.get_local_variable(
            name,
            shape=x.shape,
            initializer=lambda *_, **__: tf.zeros_like(x),
            use_resource=True)

    batch_size = len(batch_env)
    # One nested structure of state variables per model.
    self._state = [
        nested.map(var_like,
                   self._cell[mdl].zero_state(batch_size, tf.float32))
        for mdl in range(self._num_models)
    ]
    self._prev_action = tf.get_local_variable(
        'prev_action_var',
        shape=self._batch_env.action.shape,
        initializer=lambda *_, **__: tf.zeros_like(self._batch_env.action),
        use_resource=True)
def perform(self, agent_indices, observ, env_state=None):
    """Compute the next action for the selected agents and store the state.

    Each call increments the Python-side counter `self._length`. When
    `config.aug_fn` is set, the observation is passed through it with
    `phase='plan'` before preprocessing.

    Args:
      agent_indices: Indices of the agents to act for.
      observ: Batch of raw observations, one per selected agent.
      env_state: Optional value forwarded to the planner as `env_state`.

    Returns:
      Tuple of the action tensor and an empty string constant.
    """
    config = self._config
    self._length = self._length + 1
    if config.aug_fn is not None:
        print('augmented agent')
        augmented = config.aug_fn({'image': observ}, phase='plan')
        observ = augmented['image']
    observ = config.preprocess_fn(observ)
    # Add a length-one axis for the encoder call and drop it again.
    embedded = config.encoder({'image': observ[:, None]})[:, 0]

    def select_agents(tensor):
        return tf.gather(tensor, agent_indices)

    def store_agents(var, val):
        return tf.scatter_update(var, agent_indices, val)

    state = nested.map(select_agents, self._state)
    # Read the previous action before it is overwritten further below.
    prev_action = self._prev_action + 0
    with tf.control_dependencies([prev_action]):
        use_obs = tf.ones(tf.shape(agent_indices), tf.bool)[:, None]
        _, state = self._cell((embedded, prev_action, use_obs), state)
    plan = config.planner(
        self._cell, config.objective, state,
        embedded.shape[1:].as_list(),
        prev_action.shape[1:].as_list(),
        env_state=env_state)
    # Only the first action of the plan is executed.
    action = plan[:, 0]
    exploration = config.exploration
    if exploration:
        scale = exploration.scale
        if exploration.schedule:
            scale *= exploration.schedule(self._step)
        action = tfd.Normal(action, scale).sample()
    action = tf.clip_by_value(action, -1, 1)
    remember_action = self._prev_action.assign(action)
    remember_state = nested.map(
        store_agents, self._state, state, flatten=True)
    with tf.control_dependencies(remember_state + (remember_action, )):
        return tf.identity(action), tf.constant('')
def perform(self, agent_indices, observ, env_state=None):
    """Replay the next recorded rival step and run the planner alongside it.

    The `observ` argument is replaced immediately by the observation read
    from `self.rival`; the returned action is the recorded one, not the
    planner output. The planner result is passed to `self.gd_save` when
    `collect[0][0]` is true. Each call increments the Python-side counter
    `self._length`.

    Args:
      agent_indices: Indices of the agents to act for.
      observ: Ignored; overwritten by the rival observation.
      env_state: Optional value forwarded to the planner as `env_state`.

    Returns:
      Tuple of the action tensor and an empty string constant.
    """
    config = self._config
    observ, action, all_action, collect, all_reward = self.rival.get_next()
    print('wtf', observ, action, all_action, collect, all_reward)
    # Drop the leading axis of size one from the fetched tensors.
    observ = tf.squeeze(observ, 0)
    all_action = tf.squeeze(all_action, 0)
    all_reward = tf.squeeze(all_reward, 0)
    self._length = self._length + 1
    if config.aug_fn is not None:
        print('augmented agent')
        augmented = config.aug_fn({'image': observ}, phase='plan')
        observ = augmented['image']
    observ = config.preprocess_fn(observ)
    embedded = config.encoder({'image': observ[:, None]})[:, 0]

    def select_agents(tensor):
        return tf.gather(tensor, agent_indices)

    def store_agents(var, val):
        return tf.scatter_update(var, agent_indices, val)

    state = nested.map(select_agents, self._state)
    # Read the previous action before it is overwritten further below.
    prev_action = self._prev_action + 0
    with tf.control_dependencies([prev_action]):
        use_obs = tf.ones(tf.shape(agent_indices), tf.bool)[:, None]
        _, state = self._cell((embedded, prev_action, use_obs), state)
    triple = config.planner(
        self._cell, config.objective, state, all_action, all_reward,
        embedded.shape[1:].as_list(),
        prev_action.shape[1:].as_list(),
        env_state=env_state)
    save = tf.cond(collect[0][0],
                   lambda: self.gd_save(triple),
                   lambda: tf.no_op())
    # Make the returned action wait for the optional save.
    with tf.control_dependencies([save]):
        action = action[0]
    remember_action = self._prev_action.assign(action)
    remember_state = nested.map(
        store_agents, self._state, state, flatten=True)
    with tf.control_dependencies(remember_state + (remember_action, )):
        return tf.identity(action), tf.constant('')
def begin_episode(self, agent_indices):
    """Zero every model's stored state for the given agents.

    The previous action variable is cleared as well.

    Args:
      agent_indices: Indices of the agents whose state rows are reset.

    Returns:
      Empty string constant that depends on all reset operations.
    """
    def select_agents(tensor):
        return tf.gather(tensor, agent_indices)

    def zero_agents(var, val):
        return tf.scatter_update(var, agent_indices, 0 * val)

    reset_state = []
    for mdl in range(self._num_models):
        model_state = self._state[mdl]
        current = nested.map(select_agents, model_state)
        reset_state.append(
            nested.map(zero_agents, model_state, current, flatten=True))
    reset_prev_action = self._prev_action.assign(
        tf.zeros_like(self._prev_action))
    # Join the per-model groups of reset ops into one flat group.
    controldep = sum(reset_state[1:], reset_state[0])
    with tf.control_dependencies(controldep + (reset_prev_action,)):
        return tf.constant('')
def open_loop(cell, embedded, prev_action, context=1, debug=False):
    """Unroll `cell` with observations for `context` steps, then without.

    The first `context` steps use the embedded observations. The remaining
    steps start from the resulting final state, receive zeroed embeddings
    and have observations disabled. Both parts are concatenated along
    axis 1.

    Args:
      cell: RNN cell passed to `dynamic_rnn`.
      embedded: Embedded observations.
      prev_action: Previous actions aligned with `embedded`.
      context: Number of leading steps that use observations.
      debug: When true, assert that the joined state has as many steps along
        axis 1 as `embedded`.

    Returns:
      Nested state structure covering all steps.
    """
    dynamic_rnn = tf.compat.v1.nn.dynamic_rnn
    observed = tf.ones(tf.shape(embedded[:, :context, :1])[:3], tf.bool)
    closed_outputs, last_state = dynamic_rnn(
        cell,
        (embedded[:, :context], prev_action[:, :context], observed),
        dtype=tf.float32)
    _, closed_state = closed_outputs
    unobserved = tf.zeros(
        tf.shape(input=embedded[:, context:, :1])[:3], tf.bool)
    open_outputs, _ = dynamic_rnn(
        cell,
        (0 * embedded[:, context:], prev_action[:, context:], unobserved),
        initial_state=last_state)
    _, open_state = open_outputs

    def join_steps(closed, opened):
        return tf.concat([closed, opened], 1)

    state = nested.map(join_steps, closed_state, open_state)
    if debug:
        steps_match = tf.compat.v1.assert_equal(
            tf.shape(input=nested.flatten(state)[0])[1],
            tf.shape(input=embedded)[1])
        with tf.control_dependencies([steps_match]):
            state = nested.map(tf.identity, state)
    return state
def reshape_as(tensor, reference):
    """Reshape `tensor` to the shape of `reference`.

    Static dimensions of `reference` are used where known; unknown ones are
    filled in from its dynamic shape. Lists, tuples and dicts are handled by
    reshaping every element to the same `reference`.

    Args:
      tensor: Tensor-like value, or a list/tuple/dict of such values.
      reference: Tensor-like value whose shape is the target.

    Returns:
      Reshaped tensor, or a structure of reshaped tensors.
    """
    if isinstance(tensor, (list, tuple, dict)):
        # `nested.map` takes the function first and the structures after it.
        return nested.map(lambda x: reshape_as(x, reference), tensor)
    tensor = tf.convert_to_tensor(tensor)
    reference = tf.convert_to_tensor(reference)
    statics = reference.shape.as_list()
    dynamics = tf.shape(reference)
    shape = [
        static if static is not None else dynamics[index]
        for index, static in enumerate(statics)
    ]
    return tf.reshape(tensor, shape)
def perform(self, agent_indices, observ):
    """Compute the next action with e-greedy or Gaussian exploration.

    With exploration enabled and `config.discrete_action` set, the action is
    resampled from a categorical distribution mixing a uniform distribution
    (weight `scale`) with the planner output (weight `1 - scale`) and
    returned as a one-hot vector. Otherwise Gaussian noise with standard
    deviation `scale` is added.

    Args:
      agent_indices: Indices of the agents to act for.
      observ: Batch of raw observations, one per selected agent.

    Returns:
      Tuple of the action tensor and an empty string constant.
    """
    config = self._config

    def select_agents(tensor):
        return tf.gather(tensor, agent_indices)

    def store_agents(var, val):
        return tf.scatter_update(var, agent_indices, val)

    observ = config.preprocess_fn(observ)
    # Adds sequence dimension to observation tensor and then discards it.
    embedded = config.encoder({'image': observ[:, None]})[:, 0]
    state = nested.map(select_agents, self._state)
    # Read the previous action before it is overwritten further below.
    prev_action = self._prev_action + 0
    with tf.control_dependencies([prev_action]):
        use_obs = tf.ones(tf.shape(agent_indices), tf.bool)[:, None]
        _, state = self._cell((embedded, prev_action, use_obs), state)
    action = config.planner(
        self._cell, config.objective, state,
        embedded.shape[1:].as_list(),
        prev_action.shape[1:].as_list())
    exploration = config.exploration
    if exploration:
        scale = exploration.scale
        if exploration.schedule:
            scale *= exploration.schedule(self._step)
        if not config.discrete_action:
            tf.logging.info("Exploration using random noise.")
            # Add random noise to continuous action.
            action = tfd.Normal(action, scale).sample()
        else:
            tf.logging.info("Exploration using e-greedy policy.")
            action_num = self._batch_env.action.shape[1].value
            # e-greedy policy: uniform share plus the planner's share.
            probs = tf.ones_like(action) * scale / action_num
            probs += (1.0 - scale) * action
            indices = tfd.Categorical(probs=probs).sample()
            action = tf.one_hot(indices, depth=action_num, dtype=tf.float32)
    action = tf.clip_by_value(action, -1, 1)
    remember_action = self._prev_action.assign(action)
    remember_state = nested.map(
        store_agents, self._state, state, flatten=True)
    with tf.control_dependencies(remember_state + (remember_action, )):
        return tf.identity(action), tf.constant('')
def __init__(self, batch_env, step, is_training, should_log, config):
    """Create the state variables and the previous-action variable.

    The `var_like` helper had been commented out while still being passed to
    `nested.map`; it is defined locally again so the constructor does not
    depend on a name that is not visible here.

    Args:
      batch_env: Batched environment; `len(batch_env)` gives the batch size
        and `batch_env.action` provides the action shape.
      step: Trainer step, not environment step.
      is_training: Stored on the instance.
      should_log: Stored on the instance.
      config: Configuration providing `cell`.
    """
    self._batch_env = batch_env
    self._step = step  # Trainer step, not environment step.
    self._is_training = is_training
    self._should_log = should_log
    self._config = config
    self._cell = config.cell
    state = self._cell.zero_state(len(batch_env), tf.float32)

    def var_like(x):
        # Local variable named after the tensor and initialized to zeros.
        return tf.compat.v1.get_local_variable(
            x.name.split(':')[0].replace('/', '_') + '_var',
            shape=x.shape,
            initializer=lambda *_, **__: tf.compat.v1.zeros_like(x),
            use_resource=True)

    self._state = nested.map(var_like, state)
    self._prev_action = tf.compat.v1.get_local_variable(
        'prev_action_var',
        shape=self._batch_env.action.shape,
        initializer=lambda *_, **__: tf.compat.v1.zeros_like(
            self._batch_env.action),
        use_resource=True)
def _merge_dims(tensor, dims): """Flatten consecutive axes of a tensor trying to preserve static shapes.""" if isinstance(tensor, (list, tuple, dict)): return nested.map(tensor, lambda x: _merge_dims(x, dims)) tensor = tf.convert_to_tensor(tensor) if (np.array(dims) - min(dims) != np.arange(len(dims))).all(): raise ValueError('Dimensions to merge must all follow each other.') start, end = dims[0], dims[-1] output = tf.reshape( tensor, tf.concat([ tf.shape(tensor)[:start], [tf.reduce_prod(tf.shape(tensor)[start:end + 1])], tf.shape(tensor)[end + 1:] ], axis=0)) merged = tensor.shape[start:end + 1].as_list() output.set_shape(tensor.shape[:start].as_list() + [None if None in merged else np.prod(merged)] + tensor.shape[end + 1:].as_list()) return output
def __init__(self, batch_env, step, is_training, should_log, config):
    """Create the state variables, action variable and rollout saver.

    Args:
      batch_env: Batched environment; `len(batch_env)` gives the batch size
        and `batch_env.action` provides the action shape.
      step: Trainer step, not environment step.
      is_training: Stored on the instance.
      should_log: Stored on the instance.
      config: Configuration providing `cell` and `logdir`; rollouts are
        saved below `<logdir>/rollout`.
    """
    self._batch_env = batch_env
    self._step = step  # Trainer step, not environment step.
    self._is_training = is_training
    self._should_log = should_log
    self._config = config
    self._cell = config.cell
    self._length = 0
    self.num_episodes = 0
    rollout_dir = os.path.join(config.logdir, 'rollout')
    self.saver = rolloutSaver(rollout_dir)

    def var_like(x):
        # Local variable named after the tensor and initialized to zeros.
        name = x.name.split(':')[0].replace('/', '_') + '_var'
        return tf.get_local_variable(
            name,
            shape=x.shape,
            initializer=lambda *_, **__: tf.zeros_like(x),
            use_resource=True)

    state = self._cell.zero_state(len(batch_env), tf.float32)
    self._state = nested.map(var_like, state)
    self._prev_action = tf.get_local_variable(
        'prev_action_var',
        shape=self._batch_env.action.shape,
        initializer=lambda *_, **__: tf.zeros_like(self._batch_env.action),
        use_resource=True)
def _merge_dims(tensor, dims): # tensor: shape(51,40,50,1024) """Flatten consecutive axes of a tensor trying to preserve static shapes.""" if isinstance(tensor, (list, tuple, dict)): return nested.map(tensor, lambda x: _merge_dims(x, dims)) tensor = tf.convert_to_tensor(tensor) if (np.array(dims) - min(dims) != np.arange(len(dims))).all(): raise ValueError('Dimensions to merge must all follow each other.') start, end = dims[0], dims[-1] # start, end = 1, 2 output = tf.reshape( tensor, tf.concat( [ # tf.reshape(tensor, shape[51,2000,1024]) tf.shape(tensor) [:start], # [51,] # tf.shape(tensor):(51,40,50,1024) [tf.reduce_prod(tf.shape(tensor)[start:end + 1])], # [40*50,] tf.shape(tensor)[end + 1:] ], axis=0)) # [1024,] merged = tensor.shape[start:end + 1].as_list() # [40,50] output.set_shape(tensor.shape[:start].as_list() + # [51]+ [None if None in merged else np.prod(merged)] + # [2000]+ tensor.shape[end + 1:].as_list()) # [1024] return output
def __init__(self, batch_env, step, is_training, should_log, config):
    """Create state variables and an iterator over recorded rival rollouts.

    Rival episodes are read from `benchmark/<config.rival>/rollout` through
    a generator-backed dataset that is batched with size one; the one-shot
    iterator is stored as `self.rival`.

    Args:
      batch_env: Batched environment; `len(batch_env)` gives the batch size
        and `batch_env.action` provides the action shape.
      step: Trainer step, not environment step.
      is_training: Stored on the instance.
      should_log: Stored on the instance.
      config: Configuration providing `cell`, `logdir` and `rival`.
    """
    import functools

    print('mpc agent dual2')
    self._batch_env = batch_env
    self._step = step  # Trainer step, not environment step.
    self._is_training = is_training
    self._should_log = should_log
    self._config = config
    self._cell = config.cell
    self._length = 0
    self.logdir = config.logdir
    self.rival_dir = os.path.join('benchmark', config.rival, 'rollout')

    def var_like(x):
        # Local variable named after the tensor and initialized to zeros.
        name = x.name.split(':')[0].replace('/', '_') + '_var'
        return tf.get_local_variable(
            name,
            shape=x.shape,
            initializer=lambda *_, **__: tf.zeros_like(x),
            use_resource=True)

    state = self._cell.zero_state(len(batch_env), tf.float32)
    self._state = nested.map(var_like, state)
    self._prev_action = tf.get_local_variable(
        'prev_action_var',
        shape=self._batch_env.action.shape,
        initializer=lambda *_, **__: tf.zeros_like(self._batch_env.action),
        use_resource=True)

    dtypes, shapes = numpy_episodes._read_spec2(
        numpy_episodes.episode_reader, self.rival_dir)
    episode_generator = functools.partial(gener, self.rival_dir)

    def wh(sq):
        # Pick the fields consumed downstream, in a fixed order.
        print('sq', sq['observ'])
        fields = ('observ', 'action', 'all_action', 'collect', 'all_reward')
        return tuple(sq[field] for field in fields)

    rival = tf.data.Dataset.from_generator(episode_generator, dtypes, shapes)
    rival = rival.map(wh).batch(1)
    self.rival = rival.make_one_shot_iterator()
def test_shallow_list(self):
    """Mapping a unary function over a flat list transforms each element."""
    doubled = nested.map(lambda x: 2 * x, [1, 2, 3])
    self.assertEqual([2, 4, 6], doubled)
def test_mixed_types(self):
    """Elements of different types in one list are each mapped."""
    doubled = nested.map(lambda x: x * 2, [7, 'foo'])
    self.assertEqual([14, 'foofoo'], doubled)
def test_scalar(self):
    """A bare scalar is passed to the function directly."""
    unchanged = nested.map(lambda x: x, 42)
    self.assertEqual(42, unchanged)
def step(self, action):
    """Step the wrapped environment and convert observation and reward.

    Args:
      action: Action forwarded unchanged to the wrapped environment.

    Returns:
      Tuple `(observ, reward, done, info)` with every observation leaf
      passed through `_convert_observ` and the reward through
      `_convert_reward`; `done` and `info` are returned as received.
    """
    raw_observ, raw_reward, done, info = self._env.step(action)
    converted_observ = nested.map(self._convert_observ, raw_observ)
    converted_reward = self._convert_reward(raw_reward)
    return converted_observ, converted_reward, done, info
def overshooting(cell, target, embedded, prev_action, length, amount, ignore_input=False):
    """Perform open loop rollouts from the posteriors at every step.

    First, the model is unrolled in closed loop over the embedded inputs to
    obtain posterior states for every time step. Then, we perform `amount`
    long open loop rollouts from these posteriors.

    Note that the actions should be those leading to the current time step. So
    under common convention, it contains the last actions while observations
    are the current ones.

    Input:

      target, embedded:
        [A B C D E F] [A B C D E  ]

      prev_action:
        [0 A B C D E] [0 A B C D  ]

      length:
        [6 5]

      amount:
        3

    Output:

      prior, posterior, target:
        [A B C D E F] [A B C D E  ]
        [B C D E F  ] [B C D E    ]
        [C D E F    ] [C D E      ]
        [D E F      ] [D E        ]

      mask:
        [1 1 1 1 1 1] [1 1 1 1 1 0]
        [1 1 1 1 1 0] [1 1 1 1 0 0]
        [1 1 1 1 0 0] [1 1 1 0 0 0]
        [1 1 1 0 0 0] [1 1 0 0 0 0]

    Args:
      cell: RNN cell used for both unrolls; its private `_state_size` and
        `_cell.initial_state` are read to build the initial state.
      target: Nested targets, shifted and masked alongside the inputs.
      embedded: Embedded observations.
      prev_action: Actions leading to each time step.
      length: Sequence lengths, one per batch entry.
      amount: Number of additional overshooting rows to generate.
      ignore_input: When true, observations are disabled in the first
        unroll as well.

    Returns:
      Tuple `(target, prior, posterior, mask)` with `mask` cast to bool.
    """
    # Closed loop unroll to get posterior states, which are the starting points
    # for open loop unrolls. We don't need the last time step, since we have no
    # targets for unrolls from it.
    use_obs = tf.ones(
        tf.shape(nested.flatten(embedded)[0][:, :, :1])[:3], tf.bool)
    # Replace the all-true flags by all-false ones when inputs are ignored.
    use_obs = tf.cond(tf.convert_to_tensor(ignore_input),
                      lambda: tf.zeros_like(use_obs, tf.bool), lambda: use_obs)
    # Explicit zero initial state; every entry is sized by the static first
    # axis of `use_obs`.
    # NOTE(review): `dynamic_rnn` below is called with `time_major=True`, yet
    # axis 0 of `use_obs` is used as the state's leading size here and axis 1
    # is used as `max_length` further down — confirm the intended layout.
    initial_state = {
        'mean': tf.zeros((int(use_obs.shape[0]), cell._state_size)),
        'stddev': tf.zeros((int(use_obs.shape[0]), cell._state_size)),
        'sample': tf.zeros((int(use_obs.shape[0]), cell._state_size)),
        'belief': tf.zeros((int(use_obs.shape[0]), cell._state_size)),
        'rnn_state': cell._cell.initial_state(int(use_obs.shape[0])),
    }
    (prior, posterior), _ = tf.nn.dynamic_rnn(cell,
                                              (embedded, prev_action, use_obs),
                                              length,
                                              dtype=tf.float32,
                                              initial_state=initial_state,
                                              swap_memory=False,
                                              time_major=True)

    # Arrange inputs for every iteration in the open loop unroll. Every loop
    # iteration below corresponds to one row in the docstring illustration.
    max_length = shape.shape(nested.flatten(embedded)[0])[1]
    first_output = {
        'observ': embedded,
        'prev_action': prev_action,
        'posterior': posterior,
        'target': target,
        'mask': tf.sequence_mask(length, max_length, tf.int32),
    }
    # Shift a tensor one step to the left along axis 1 and pad with zeros.
    progress_fn = lambda tensor: tf.concat([tensor[:, 1:], 0 * tensor[:, :1]], 1)
    # Apply the shift `amount` times, collecting every intermediate result.
    other_outputs = tf.scan(
        lambda past_output, _: nested.map(progress_fn, past_output),
        tf.range(amount), first_output)
    # Prepend the unshifted inputs as the first row.
    sequences = nested.map(lambda lhs, rhs: tf.concat([lhs[None], rhs], 0),
                           first_output, other_outputs)

    # Merge batch and time dimensions of steps to compute unrolls from every
    # time step as one batch. The time dimension becomes the number of
    # overshooting distances.
    sequences = nested.map(lambda tensor: _merge_dims(tensor, [1, 2]),
                           sequences)
    # Swap the first two axes, keeping all remaining axes in place.
    sequences = nested.map(
        lambda tensor: tf.transpose(tensor, [1, 0] + list(
            range(2, tensor.shape.ndims))), sequences)
    merged_length = tf.reduce_sum(sequences['mask'], 1)

    # Mask out padding frames; unnecessary if the input is already masked.
    sequences = nested.map(
        lambda tensor: tensor * tf.cast(
            _pad_dims(sequences['mask'], tensor.shape.ndims), tensor.dtype),
        sequences)

    # Compute open loop rollouts.
    use_obs = tf.zeros(tf.shape(sequences['mask']), tf.bool)[..., None]
    # Shift the posteriors one step to the right along axis 1 so that each
    # rollout starts from the state preceding it; the first slot is zeros.
    prev_state = nested.map(
        lambda tensor: tf.concat([0 * tensor[:, :1], tensor[:, :-1]], 1),
        posterior)
    prev_state = nested.map(lambda tensor: _merge_dims(tensor, [0, 1]),
                            prev_state)
    (priors, _), _ = tf.nn.dynamic_rnn(
        cell, (sequences['observ'], sequences['prev_action'], use_obs),
        merged_length, prev_state)

    # Restore batch dimension.
    target, prior, posterior, mask = nested.map(
        functools.partial(_restore_batch_dim,
                          batch_size=shape.shape(length)[0]),
        (sequences['target'], priors, sequences['posterior'],
         sequences['mask']))

    mask = tf.cast(mask, tf.bool)
    return target, prior, posterior, mask
def reset(self):
    """Reset the wrapped environment and convert its observation.

    Returns:
      The initial observation with every leaf passed through
      `_convert_observ`.
    """
    initial_observ = self._env.reset()
    return nested.map(self._convert_observ, initial_observ)
def test_multiple_lists(self):
    """Several lists are traversed in lockstep, element by element."""
    first, second, third = [1, 2, 3], [4, 5, 6], [7, 8, 9]
    totals = nested.map(lambda x, y, z: x + y + z, first, second, third)
    self.assertEqual([12, 15, 18], totals)
def test_empty(self):
    """Mapping over an empty dict gives back an empty dict."""
    mapped = nested.map(lambda x: x, {})
    self.assertEqual({}, mapped)