def test_namedtuple(self):
  Foo = collections.namedtuple('Foo', 'value')
  foo, bar = [Foo(42)], [Foo(13)]
  result = nested.map(lambda x, y: (y, x), foo, bar)
  self.assertEqual([Foo((13, 42))], result)
  result = nested.map(lambda x, y: x + y, foo, bar)
  self.assertEqual([Foo(55)], result)

def chunk_sequence(sequence, chunk_length, randomize=True, num_chunks=None):
  """Split a nested sequence into chunks of `chunk_length` time steps.

  A `length` entry, if present, gives the number of valid steps. With
  `randomize`, chunking starts from a random offset into the sequence.
  """
  if 'length' in sequence:
    length = sequence.pop('length')
  else:
    length = tf.shape(nested.flatten(sequence)[0])[0]
  if randomize:
    if not num_chunks:
      num_chunks = tf.maximum(1, length // chunk_length - 1)
    else:
      num_chunks = num_chunks + 0 * length
    used_length = num_chunks * chunk_length
    max_offset = length - used_length
    offset = tf.random_uniform((), 0, max_offset + 1, dtype=tf.int32)
  else:
    if num_chunks is None:
      num_chunks = length // chunk_length
    else:
      num_chunks = num_chunks + 0 * length
    used_length = num_chunks * chunk_length
    offset = 0
  clipped = nested.map(
      lambda tensor: tensor[offset: offset + used_length],
      sequence)
  chunks = nested.map(
      lambda tensor: tf.reshape(
          tensor, [num_chunks, chunk_length] + tensor.shape[1:].as_list()),
      clipped)
  return chunks

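# A minimal usage sketch for chunk_sequence (not part of the original module).
# The tensor shapes and the layout of `sequence` are illustrative assumptions;
# the call only relies on the function above and TensorFlow 1.x graph mode,
# with `tf` imported as elsewhere in this file.
def _example_chunk_sequence():
  sequence = {
      'observ': tf.zeros([100, 64]),  # 100 time steps of 64-d features.
      'action': tf.zeros([100, 4]),   # 100 time steps of 4-d actions.
      'length': tf.constant(100),     # Number of valid (unpadded) steps.
  }
  chunks = chunk_sequence(sequence, chunk_length=10, randomize=False)
  # At runtime, chunks['observ'] has shape [10, 10, 64] and chunks['action']
  # has shape [10, 10, 4]: ten chunks of ten steps each.
  return chunks
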
def closed_loop(cell, embedded, prev_action, debug=False):
  """Unroll the cell over the whole sequence, conditioning on all observations."""
  use_obs = tf.ones(tf.shape(embedded[:, :, :1])[:3], tf.bool)
  (prior, posterior), _ = tf.nn.dynamic_rnn(
      cell, (embedded, prev_action, use_obs), dtype=tf.float32)
  if debug:
    with tf.control_dependencies([tf.assert_equal(
        tf.shape(nested.flatten(posterior)[0])[1], tf.shape(embedded)[1])]):
      prior = nested.map(tf.identity, prior)
      posterior = nested.map(tf.identity, posterior)
  return prior, posterior

def reset(self, agent_indices):
  """Zero out the latent state for the given agents and clear the previous action."""
  state = nested.map(
      lambda tensor: tf.gather(tensor, agent_indices),
      self._state)
  reset_state = nested.map(
      lambda var, val: tf.scatter_update(var, agent_indices, 0 * val),
      self._state, state, flatten=True)
  reset_prev_action = self._prev_action.assign(
      tf.zeros_like(self._prev_action))
  return tf.group(reset_prev_action, *reset_state)

def open_loop(cell, embedded, prev_action, context=1, debug=False):
  """Condition on the first `context` observations, then predict the rest open loop."""
  use_obs = tf.ones(tf.shape(embedded[:, :context, :1])[:3], tf.bool)
  (_, closed_state), last_state = tf.nn.dynamic_rnn(
      cell, (embedded[:, :context], prev_action[:, :context], use_obs),
      dtype=tf.float32)
  use_obs = tf.zeros(tf.shape(embedded[:, context:, :1])[:3], tf.bool)
  (_, open_state), _ = tf.nn.dynamic_rnn(
      cell, (0 * embedded[:, context:], prev_action[:, context:], use_obs),
      initial_state=last_state)
  state = nested.map(
      lambda x, y: tf.concat([x, y], 1),
      closed_state, open_state)
  if debug:
    with tf.control_dependencies([tf.assert_equal(
        tf.shape(nested.flatten(state)[0])[1], tf.shape(embedded)[1])]):
      state = nested.map(tf.identity, state)
  return state

def step(self, agent_indices, observ):
  """Compute actions for the given agents, update their state, and apply exploration."""
  observ = self._config.preprocess_fn(observ)
  # Add a time dimension so the model sees a length-one sequence.
  observ = nested.map(lambda x: x[:, None], observ)
  embedded = self._config.encoder(observ)[:, 0]
  state = nested.map(
      lambda tensor: tf.gather(tensor, agent_indices),
      self._state)
  prev_action = self._prev_action + 0
  with tf.control_dependencies([prev_action]):
    use_obs = tf.ones(tf.shape(agent_indices), tf.bool)[:, None]
    _, state = self._cell((embedded, prev_action, use_obs), state)
  action = self._config.planner(
      self._cell, self._config.objective, state,
      embedded.shape[1:].as_list(),
      prev_action.shape[1:].as_list())
  action = action[:, 0]
  if self._config.exploration:
    expl = self._config.exploration
    scale = tf.cast(expl.scale, tf.float32)[None]  # Batch dimension.
    if expl.schedule:
      scale *= expl.schedule(self._step)
    if expl.factors:
      scale *= np.array(expl.factors)
    if expl.type == 'additive_normal':
      action = tfd.Normal(action, scale[:, None]).sample()
    elif expl.type == 'epsilon_greedy':
      random_action = tf.one_hot(
          tfd.Categorical(0 * action).sample(), action.shape[-1])
      switch = tf.cast(
          tf.less(tf.random.uniform((self._num_envs,)), scale),
          tf.float32)[:, None]
      action = switch * random_action + (1 - switch) * action
    else:
      raise NotImplementedError(expl.type)
  action = tf.clip_by_value(action, -1, 1)
  remember_action = self._prev_action.assign(action)
  remember_state = nested.map(
      lambda var, val: tf.scatter_update(var, agent_indices, val),
      self._state, state, flatten=True)
  with tf.control_dependencies(remember_state + (remember_action,)):
    return tf.identity(action)

def planned(
    cell, objective_fn, embedded, prev_action, planner, context=1, length=20,
    amount=1000, debug=False):
  """Condition on `context` observations, then plan actions for `length` further steps."""
  use_obs = tf.ones(tf.shape(embedded[:, :context, :1])[:3], tf.bool)
  (_, closed_state), last_state = tf.nn.dynamic_rnn(
      cell, (embedded[:, :context], prev_action[:, :context], use_obs),
      dtype=tf.float32)
  _, plan_state, return_ = planner(
      cell, objective_fn, last_state,
      obs_shape=shape.shape(embedded)[2:],
      action_shape=shape.shape(prev_action)[2:],
      horizon=length, amount=amount)
  state = nested.map(
      lambda x, y: tf.concat([x, y], 1),
      closed_state, plan_state)
  if debug:
    with tf.control_dependencies([tf.assert_equal(
        tf.shape(nested.flatten(state)[0])[1], context + length)]):
      state = nested.map(tf.identity, state)
      return_ = tf.identity(return_)
  return state, return_

def __init__(self, batch_env, step, is_training, should_log, config):
  self._step = step  # Trainer step, not environment step.
  self._is_training = is_training
  self._should_log = should_log
  self._config = config
  self._cell = config.cell
  self._num_envs = len(batch_env)
  state = self._cell.zero_state(self._num_envs, tf.float32)
  var_like = lambda x: tf.get_local_variable(
      x.name.split(':')[0].replace('/', '_') + '_var',
      shape=x.shape,
      initializer=lambda *_, **__: tf.zeros_like(x),
      use_resource=True)
  self._state = nested.map(var_like, state)
  batch_action_shape = (self._num_envs,) + batch_env.action_space.shape
  self._prev_action = tf.get_local_variable(
      'prev_action_var', shape=batch_action_shape,
      initializer=lambda *_, **__: tf.zeros(batch_action_shape),
      use_resource=True)

def _merge_dims(tensor, dims):
  """Flatten consecutive dimensions of a tensor, preserving static shapes where possible."""
  if isinstance(tensor, (list, tuple, dict)):
    return nested.map(lambda x: _merge_dims(x, dims), tensor)
  tensor = tf.convert_to_tensor(tensor)
  if (np.array(dims) - min(dims) != np.arange(len(dims))).any():
    raise ValueError('Dimensions to merge must all follow each other.')
  start, end = dims[0], dims[-1]
  output = tf.reshape(tensor, tf.concat([
      tf.shape(tensor)[:start],
      [tf.reduce_prod(tf.shape(tensor)[start: end + 1])],
      tf.shape(tensor)[end + 1:]], axis=0))
  merged = tensor.shape[start: end + 1].as_list()
  output.set_shape(
      tensor.shape[:start].as_list() +
      [None if None in merged else np.prod(merged)] +
      tensor.shape[end + 1:].as_list())
  return output

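# A minimal usage sketch for _merge_dims (not part of the original module).
# The shapes are illustrative assumptions; the call only relies on the helper
# above and TensorFlow 1.x, with `tf` imported as elsewhere in this file.
def _example_merge_dims():
  tensor = tf.zeros([8, 5, 3, 2])
  merged = _merge_dims(tensor, [1, 2])
  # Axes 1 and 2 are collapsed into one, giving static shape [8, 15, 2].
  return merged
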
def test_shallow_list(self):
  self.assertEqual([2, 4, 6], nested.map(lambda x: 2 * x, [1, 2, 3]))

def test_empty(self):
  self.assertEqual({}, nested.map(lambda x: x, {}))

def test_scalar(self):
  self.assertEqual(42, nested.map(lambda x: x, 42))

def test_multiple_lists(self):
  a = [1, 2, 3]
  b = [4, 5, 6]
  c = [7, 8, 9]
  result = nested.map(lambda x, y, z: x + y + z, a, b, c)
  self.assertEqual([12, 15, 18], result)

def test_mixed_types(self):
  self.assertEqual([14, 'foofoo'], nested.map(lambda x: x * 2, [7, 'foo']))

def test_mixed_structure(self):
  structure = [(1, 2), 3, {'foo': [4]}]
  result = nested.map(lambda x: 2 * x, structure)
  self.assertEqual([(2, 4), 6, {'foo': [8]}], result)

def test_shallow_dict(self):
  data = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
  self.assertEqual(data, nested.map(lambda x: x, data))

def overshooting(
    cell, target, embedded, prev_action, length, amount, posterior=None,
    ignore_input=False):
  """Latent overshooting: open-loop predictions from every posterior state.

  Returns targets, priors, posteriors, and a mask, with an extra dimension for
  the overshooting distance (up to `amount` steps).
  """
  # Closed loop unroll to get posterior states, which are the starting points
  # for open loop unrolls. We don't need the last time step, since we have no
  # targets for unrolls from it.
  if posterior is None:
    use_obs = tf.ones(
        tf.shape(nested.flatten(embedded)[0][:, :, :1])[:3], tf.bool)
    use_obs = tf.cond(
        tf.convert_to_tensor(ignore_input),
        lambda: tf.zeros_like(use_obs, tf.bool),
        lambda: use_obs)
    (_, posterior), _ = tf.nn.dynamic_rnn(
        cell, (embedded, prev_action, use_obs), length,
        dtype=tf.float32, swap_memory=True)
  # Arrange inputs for every iteration in the open loop unroll. Every loop
  # iteration below corresponds to one row in the docstring illustration.
  max_length = shape.shape(nested.flatten(embedded)[0])[1]
  first_output = {
      # 'observ': embedded,
      'prev_action': prev_action,
      'posterior': posterior,
      'target': target,
      'mask': tf.sequence_mask(length, max_length, tf.int32),
  }
  progress_fn = lambda tensor: tf.concat([tensor[:, 1:], 0 * tensor[:, :1]], 1)
  other_outputs = tf.scan(
      lambda past_output, _: nested.map(progress_fn, past_output),
      tf.range(amount), first_output)
  sequences = nested.map(
      lambda lhs, rhs: tf.concat([lhs[None], rhs], 0),
      first_output, other_outputs)
  # Merge batch and time dimensions of steps to compute unrolls from every
  # time step as one batch. The time dimension becomes the number of
  # overshooting distances.
  sequences = nested.map(
      lambda tensor: _merge_dims(tensor, [1, 2]),
      sequences)
  sequences = nested.map(
      lambda tensor: tf.transpose(
          tensor, [1, 0] + list(range(2, tensor.shape.ndims))),
      sequences)
  merged_length = tf.reduce_sum(sequences['mask'], 1)
  # Mask out padding frames; unnecessary if the input is already masked.
  sequences = nested.map(
      lambda tensor: tensor * tf.cast(
          _pad_dims(sequences['mask'], tensor.shape.ndims), tensor.dtype),
      sequences)
  # Compute open loop rollouts.
  use_obs = tf.zeros(tf.shape(sequences['mask']), tf.bool)[..., None]
  embed_size = nested.flatten(embedded)[0].shape[2].value
  obs = tf.zeros(shape.shape(sequences['mask']) + [embed_size])
  prev_state = nested.map(
      lambda tensor: tf.concat([0 * tensor[:, :1], tensor[:, :-1]], 1),
      posterior)
  prev_state = nested.map(
      lambda tensor: _merge_dims(tensor, [0, 1]),
      prev_state)
  (priors, _), _ = tf.nn.dynamic_rnn(
      cell, (obs, sequences['prev_action'], use_obs), merged_length,
      prev_state)
  # Restore batch dimension.
  target, prior, posterior, mask = nested.map(
      functools.partial(_restore_batch_dim, batch_size=shape.shape(length)[0]),
      (sequences['target'], priors, sequences['posterior'],
       sequences['mask']))
  mask = tf.cast(mask, tf.bool)
  return target, prior, posterior, mask

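# A minimal NumPy sketch (not part of the original module) of the shifting
# trick used in overshooting(): tf.scan with progress_fn creates `amount`
# copies of each sequence, each shifted one step further to the left and
# zero-padded at the end, so that stacking them yields an unroll starting
# from every time step. The array contents below are illustrative
# assumptions; `np` is imported as elsewhere in this file.
def _example_shifted_copies(sequence, amount):
  outputs = [sequence]
  for _ in range(amount):
    shifted = np.concatenate(
        [outputs[-1][:, 1:], 0 * outputs[-1][:, :1]], axis=1)
    outputs.append(shifted)
  return np.stack(outputs, axis=0)

# _example_shifted_copies(np.array([[1, 2, 3, 4]]), 2) returns
# [[[1, 2, 3, 4]], [[2, 3, 4, 0]], [[3, 4, 0, 0]]].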