Example 1
 def test_namedtuple(self):
     Foo = collections.namedtuple('Foo', 'value')
     foo, bar = [Foo(42)], [Foo(13)]
     function = nested.map(lambda x, y: (y, x), foo, bar)
     self.assertEqual([Foo((13, 42))], function)
     function = nested.map(lambda x, y: x + y, foo, bar)
     self.assertEqual([Foo(55)], function)
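These tests pin down the behavior of nested.map: it applies a function to the leaves of arbitrarily nested lists, tuples, namedtuples, and dicts, zipping several structures of the same shape when more than one is given. A minimal sketch of those semantics (an illustration only, not the library's actual implementation, which also supports extras such as the flatten=True keyword used in later examples):

def map_sketch(function, *structures):
    # Apply function to corresponding leaves of one or more nested structures.
    first = structures[0]
    if isinstance(first, dict):
        return {k: map_sketch(function, *(s[k] for s in structures))
                for k in first}
    if isinstance(first, (list, tuple)):
        mapped = [map_sketch(function, *leaves)
                  for leaves in zip(*structures)]
        if hasattr(first, '_fields'):  # Rebuild namedtuples field by field.
            return type(first)(*mapped)
        return type(first)(mapped)
    return function(*structures)  # Leaf values: apply the function directly.

Under this sketch, map_sketch(lambda x, y: x + y, [Foo(42)], [Foo(13)]) evaluates to [Foo(55)], matching the assertion above.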
Example 2
def chunk_sequence(sequence, chunk_length, randomize=True, num_chunks=None):
    if 'length' in sequence:
        length = sequence.pop('length')
    else:
        length = tf.shape(nested.flatten(sequence)[0])[0]
    if randomize:
        if num_chunks is None:
            num_chunks = tf.maximum(1, length // chunk_length - 1)
        else:
            num_chunks = num_chunks + 0 * length  # Convert the Python int to a tensor.
        used_length = num_chunks * chunk_length
        max_offset = length - used_length
        offset = tf.random_uniform((), 0, max_offset + 1, dtype=tf.int32)
    else:
        if num_chunks is None:
            num_chunks = length // chunk_length
        else:
            num_chunks = num_chunks + 0 * length
        used_length = num_chunks * chunk_length
        offset = 0
    clipped = nested.map(lambda tensor: tensor[offset:offset + used_length],
                         sequence)
    chunks = nested.map(
        lambda tensor: tf.reshape(
            tensor, [num_chunks, chunk_length] + tensor.shape[1:].as_list()),
        clipped)
    return chunks
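A hypothetical graph-mode usage of chunk_sequence on a toy episode; the keys and sizes are made up for illustration:

sequence = {
    'observ': tf.zeros([100, 64]),  # 100 time steps of 64-dim features.
    'action': tf.zeros([100, 4]),
}
chunks = chunk_sequence(sequence, chunk_length=50)
# With randomize=True and num_chunks=None, num_chunks becomes
# max(1, 100 // 50 - 1) = 1, so chunks['observ'] has shape [1, 50, 64],
# taken from a random offset within the episode.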
Example 3
def closed_loop(cell, embedded, prev_action, debug=False):
  use_obs = tf.ones(tf.shape(embedded[:, :, :1])[:3], tf.bool)
  (prior, posterior), _ = tf.nn.dynamic_rnn(
      cell, (embedded, prev_action, use_obs), dtype=tf.float32)
  if debug:
    with tf.control_dependencies([tf.assert_equal(
        tf.shape(nested.flatten(posterior)[0])[1], tf.shape(embedded)[1])]):
      prior = nested.map(tf.identity, prior)
      posterior = nested.map(tf.identity, posterior)
  return prior, posterior
Example 4
 def reset(self, agent_indices):
     state = nested.map(lambda tensor: tf.gather(tensor, agent_indices),
                        self._state)
     # Zero out the state entries of the agents being reset.
     reset_state = nested.map(
         lambda var, val: tf.scatter_update(var, agent_indices, 0 * val),
         self._state,
         state,
         flatten=True)
     reset_prev_action = self._prev_action.assign(
         tf.zeros_like(self._prev_action))
     return tf.group(reset_prev_action, *reset_state)
Example 5
def open_loop(cell, embedded, prev_action, context=1, debug=False):
  use_obs = tf.ones(tf.shape(embedded[:, :context, :1])[:3], tf.bool)
  (_, closed_state), last_state = tf.nn.dynamic_rnn(
      cell, (embedded[:, :context], prev_action[:, :context], use_obs),
      dtype=tf.float32)
  use_obs = tf.zeros(tf.shape(embedded[:, context:, :1])[:3], tf.bool)
  (_, open_state), _ = tf.nn.dynamic_rnn(
      cell, (0 * embedded[:, context:], prev_action[:, context:], use_obs),
      initial_state=last_state)
  state = nested.map(
      lambda x, y: tf.concat([x, y], 1),
      closed_state, open_state)
  if debug:
    with tf.control_dependencies([tf.assert_equal(
        tf.shape(nested.flatten(state)[0])[1], tf.shape(embedded)[1])]):
      state = nested.map(tf.identity, state)
  return state
Example 6
 def step(self, agent_indices, observ):
     observ = self._config.preprocess_fn(observ)
     # Converts observ to sequence.
     observ = nested.map(lambda x: x[:, None], observ)
     embedded = self._config.encoder(observ)[:, 0]
     state = nested.map(lambda tensor: tf.gather(tensor, agent_indices),
                        self._state)
     prev_action = self._prev_action + 0  # Read the variable into a tensor before it is reassigned below.
     with tf.control_dependencies([prev_action]):
         use_obs = tf.ones(tf.shape(agent_indices), tf.bool)[:, None]
         _, state = self._cell((embedded, prev_action, use_obs), state)
     action = self._config.planner(self._cell, self._config.objective,
                                   state, embedded.shape[1:].as_list(),
                                   prev_action.shape[1:].as_list())
     action = action[:, 0]
     if self._config.exploration:
         expl = self._config.exploration
         scale = tf.cast(expl.scale, tf.float32)[None]  # Batch dimension.
         if expl.schedule:
             scale *= expl.schedule(self._step)
         if expl.factors:
             scale *= np.array(expl.factors)
         if expl.type == 'additive_normal':
             action = tfd.Normal(action, scale[:, None]).sample()
         elif expl.type == 'epsilon_greedy':
             random_action = tf.one_hot(
                 tfd.Categorical(0 * action).sample(), action.shape[-1])
             switch = tf.cast(
                 tf.less(tf.random.uniform((self._num_envs, )), scale),
                 tf.float32)[:, None]
             action = switch * random_action + (1 - switch) * action
         else:
             raise NotImplementedError(expl.type)
     action = tf.clip_by_value(action, -1, 1)
     remember_action = self._prev_action.assign(action)
     remember_state = nested.map(
         lambda var, val: tf.scatter_update(var, agent_indices, val),
         self._state,
         state,
         flatten=True)
     with tf.control_dependencies(remember_state + (remember_action, )):
         return tf.identity(action)
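Example 6 is the first snippet to reference np and tfd. None of the examples show their imports; presumably they look roughly like this (the tfd alias in particular is an assumption), with nested and shape being project-local helper modules rather than third-party packages:

import collections
import functools

import numpy as np
import tensorflow as tf
from tensorflow_probability import distributions as tfd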
Example 7
def planned(
    cell, objective_fn, embedded, prev_action, planner, context=1, length=20,
    amount=1000, debug=False):
  use_obs = tf.ones(tf.shape(embedded[:, :context, :1])[:3], tf.bool)
  (_, closed_state), last_state = tf.nn.dynamic_rnn(
      cell, (embedded[:, :context], prev_action[:, :context], use_obs),
      dtype=tf.float32)
  _, plan_state, return_ = planner(
      cell, objective_fn, last_state,
      obs_shape=shape.shape(embedded)[2:],
      action_shape=shape.shape(prev_action)[2:],
      horizon=length, amount=amount)
  state = nested.map(
      lambda x, y: tf.concat([x, y], 1),
      closed_state, plan_state)
  if debug:
    with tf.control_dependencies([tf.assert_equal(
        tf.shape(nested.flatten(state)[0])[1], context + length)]):
      state = nested.map(tf.identity, state)
      return_ = tf.identity(return_)
  return state, return_
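Examples 7 and 17 also call shape.shape, a helper that is not shown. A common definition for such a helper, assumed here for illustration, returns a list containing static dimensions where known and scalar tensors elsewhere, which is why its result can be sliced and concatenated like a plain Python list:

def shape(tensor):
    # Static shape entries where known, dynamic scalars for the rest.
    static = tensor.shape.as_list()
    dynamic = tf.unstack(tf.shape(tensor))
    return [d if s is None else s for s, d in zip(static, dynamic)]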
Example 8
 def __init__(self, batch_env, step, is_training, should_log, config):
     self._step = step  # Trainer step, not environment step.
     self._is_training = is_training
     self._should_log = should_log
     self._config = config
     self._cell = config.cell
     self._num_envs = len(batch_env)
     state = self._cell.zero_state(self._num_envs, tf.float32)
     # Local variables persist across session runs without being checkpointed.
     var_like = lambda x: tf.get_local_variable(
         x.name.split(':')[0].replace('/', '_') + '_var',
         shape=x.shape,
         initializer=lambda *_, **__: tf.zeros_like(x),
         use_resource=True)
     self._state = nested.map(var_like, state)
     batch_action_shape = (self._num_envs, ) + batch_env.action_space.shape
     self._prev_action = tf.get_local_variable(
         'prev_action_var',
         shape=batch_action_shape,
         initializer=lambda *_, **__: tf.zeros(batch_action_shape),
         use_resource=True)
Example 9
def _merge_dims(tensor, dims):
    # Recurse into nested structures, merging the dims of every leaf tensor.
    if isinstance(tensor, (list, tuple, dict)):
        return nested.map(lambda x: _merge_dims(x, dims), tensor)
    tensor = tf.convert_to_tensor(tensor)
    # Raise if any requested dim breaks the consecutive run.
    if (np.array(dims) - min(dims) != np.arange(len(dims))).any():
        raise ValueError('Dimensions to merge must all follow each other.')
    start, end = dims[0], dims[-1]
    output = tf.reshape(tensor, tf.concat([
        tf.shape(tensor)[:start],
        [tf.reduce_prod(tf.shape(tensor)[start:end + 1])],
        tf.shape(tensor)[end + 1:],
    ], axis=0))
    # Preserve static shape information where possible.
    merged = tensor.shape[start:end + 1].as_list()
    output.set_shape(tensor.shape[:start].as_list() +
                     [None if None in merged else np.prod(merged)] +
                     tensor.shape[end + 1:].as_list())
    return output
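For instance, merging dimensions 1 and 2 of a [4, 5, 6, 7] tensor yields [4, 30, 7]; a quick hypothetical check:

x = tf.zeros([4, 5, 6, 7])
y = _merge_dims(x, [1, 2])
assert y.shape.as_list() == [4, 30, 7]  # 5 * 6 merged into 30.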
Example 10
 def test_shallow_list(self):
     self.assertEqual([2, 4, 6], nested.map(lambda x: 2 * x, [1, 2, 3]))
Example 11
 def test_empty(self):
     self.assertEqual({}, nested.map(lambda x: x, {}))
Example 12
 def test_scalar(self):
     self.assertEqual(42, nested.map(lambda x: x, 42))
Example 13
 def test_multiple_lists(self):
     a = [1, 2, 3]
     b = [4, 5, 6]
     c = [7, 8, 9]
     result = nested.map(lambda x, y, z: x + y + z, a, b, c)
     self.assertEqual([12, 15, 18], result)
Example 14
 def test_mixed_types(self):
     self.assertEqual([14, 'foofoo'], nested.map(lambda x: x * 2,
                                                 [7, 'foo']))
Example 15
 def test_mixed_structure(self):
     structure = [(1, 2), 3, {'foo': [4]}]
     result = nested.map(lambda x: 2 * x, structure)
     self.assertEqual([(2, 4), 6, {'foo': [8]}], result)
Example 16
 def test_shallow_dict(self):
     data = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
     self.assertEqual(data, nested.map(lambda x: x, data))
Example 17
def overshooting(cell,
                 target,
                 embedded,
                 prev_action,
                 length,
                 amount,
                 posterior=None,
                 ignore_input=False):
    # Closed loop unroll to get posterior states, which are the starting points
    # for open loop unrolls. We don't need the last time step, since we have no
    # targets for unrolls from it.
    if posterior is None:
        use_obs = tf.ones(
            tf.shape(nested.flatten(embedded)[0][:, :, :1])[:3], tf.bool)
        use_obs = tf.cond(tf.convert_to_tensor(ignore_input),
                          lambda: tf.zeros_like(use_obs, tf.bool),
                          lambda: use_obs)
        (_, posterior), _ = tf.nn.dynamic_rnn(cell,
                                              (embedded, prev_action, use_obs),
                                              length,
                                              dtype=tf.float32,
                                              swap_memory=True)

    # Arrange inputs for every iteration in the open loop unroll. Every loop
    # iteration below corresponds to one row in the docstring illustration.
    max_length = shape.shape(nested.flatten(embedded)[0])[1]
    first_output = {
        # 'observ': embedded,
        'prev_action': prev_action,
        'posterior': posterior,
        'target': target,
        'mask': tf.sequence_mask(length, max_length, tf.int32),
    }

    # Shift the time axis one step left, padding the end with zeros.
    progress_fn = lambda tensor: tf.concat(
        [tensor[:, 1:], 0 * tensor[:, :1]], 1)
    other_outputs = tf.scan(
        lambda past_output, _: nested.map(progress_fn, past_output),
        tf.range(amount), first_output)
    sequences = nested.map(lambda lhs, rhs: tf.concat([lhs[None], rhs], 0),
                           first_output, other_outputs)

    # Merge batch and time dimensions of steps to compute unrolls from every
    # time step as one batch. The time dimension becomes the number of
    # overshooting distances.
    sequences = nested.map(lambda tensor: _merge_dims(tensor, [1, 2]),
                           sequences)
    sequences = nested.map(
        lambda tensor: tf.transpose(
            tensor, [1, 0] + list(range(2, tensor.shape.ndims))),
        sequences)
    merged_length = tf.reduce_sum(sequences['mask'], 1)

    # Mask out padding frames; unnecessary if the input is already masked.
    sequences = nested.map(
        lambda tensor: tensor * tf.cast(
            _pad_dims(sequences['mask'], tensor.shape.ndims), tensor.dtype),
        sequences)

    # Compute open loop rollouts.
    use_obs = tf.zeros(tf.shape(sequences['mask']), tf.bool)[..., None]
    embed_size = nested.flatten(embedded)[0].shape[2].value
    obs = tf.zeros(shape.shape(sequences['mask']) + [embed_size])
    prev_state = nested.map(
        lambda tensor: tf.concat([0 * tensor[:, :1], tensor[:, :-1]], 1),
        posterior)
    prev_state = nested.map(lambda tensor: _merge_dims(tensor, [0, 1]),
                            prev_state)
    (priors, _), _ = tf.nn.dynamic_rnn(
        cell, (obs, sequences['prev_action'], use_obs),
        merged_length, prev_state)

    # Restore batch dimension.
    target, prior, posterior, mask = nested.map(
        functools.partial(_restore_batch_dim,
                          batch_size=shape.shape(length)[0]),
        (sequences['target'], priors, sequences['posterior'],
         sequences['mask']))

    mask = tf.cast(mask, tf.bool)
    return target, prior, posterior, mask
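overshooting additionally relies on two helpers that are not shown: _pad_dims, which appends singleton axes so the mask broadcasts against higher-rank tensors, and _restore_batch_dim, which splits the merged batch axis back apart. Plausible sketches, assuming they mirror _merge_dims from Example 9 and the shape helper sketched after Example 7:

def _pad_dims(tensor, rank):
    # Append singleton dimensions until the tensor reaches the given rank.
    for _ in range(rank - tensor.shape.ndims):
        tensor = tensor[..., None]
    return tensor

def _restore_batch_dim(tensor, batch_size):
    # Split the merged leading axis back into [batch, time, ...].
    initial = shape.shape(tensor)
    desired = [batch_size, initial[0] // batch_size] + initial[1:]
    return tf.reshape(tensor, desired)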