Example #1
    def _call(self, inputs, states, comm_indices, comm_directions,
              comm_distances):
        n_agents, max_others = util.get_shape_static_or_dynamic(comm_indices)
        if self.signal_size > 0:
            rnn_states, signals = states
        else:
            # no signal channel configured: carry only the RNN state and use an
            # empty (zero-width) signals tensor so the shapes below still line up
            rnn_states = states
            signals = tf.zeros([n_agents, 0])

        signal_sets, present_mask = util.gather_present(signals, comm_indices)
        if self.can_see_others:
            signal_sets = util.add_visible_agents_to_each_timestep(
                signal_sets, comm_directions, comm_distances)

        # keep at most max_agents neighbours, then pad up to that fixed width so the
        # reshape below always yields a [n_agents, signal_matrix_size] matrix
        signal_sets = signal_sets[:, :self.max_agents, :]
        signal_sets = tf.pad(
            signal_sets,
            [
                (0, 0),
                (0, tf.maximum(self.max_agents - max_others, 0)),
                (0, 0),
            ],
        )
        signals_flat = tf.reshape(signal_sets,
                                  [n_agents, self.signal_matrix_size])

        full_inputs = tf.concat([inputs, signals_flat], axis=1)
        features, new_rnn_states = self.rnn_cell.call(full_inputs, rnn_states)
        if self.signal_size > 0:
            new_own_signal = self.signal_generator.call(features)
            out_states = new_rnn_states, new_own_signal
        else:
            out_states = new_rnn_states
        return features, out_states
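
This and the following examples rely on util.gather_present(vectors, indices), which pulls each agent's visible neighbours out of a [n_agents, vector_size] tensor and also returns a presence mask. Its implementation is not shown here; a minimal sketch, assuming (as the tf.greater_equal(comm_indices, 0) checks elsewhere suggest) that negative indices mark empty neighbour slots, could look like this:

import tensorflow as tf

def gather_present_sketch(vectors, indices):
    # vectors: [n_agents, vector_size]; indices: [n_agents, max_others], negative = absent.
    mask = tf.greater_equal(indices, 0)
    # Substitute a valid dummy index (0) for absent slots so tf.gather never sees a
    # negative index, then zero those rows out afterwards.
    safe_indices = tf.where(mask, indices, tf.zeros_like(indices))
    gathered = tf.gather(vectors, safe_indices)                 # [n_agents, max_others, vector_size]
    gathered *= tf.cast(mask, vectors.dtype)[:, :, tf.newaxis]  # blank out absent neighbours
    return gathered, mask
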
Example #2
    def _call(self, inputs, states, comm_indices, comm_directions, comm_distances):
        rnn_states, signals = states

        if self._can_see_others:
            vis_features = util.build_visible_agents_features(
                comm_directions, comm_distances
            )
            vis_features = util.average_by_mask(
                vis_features, tf.greater_equal(comm_indices, 0)
            )
            inputs = tf.concat([vis_features, inputs], axis=1)
        env_features, rnn_states_after = self.rnn_cell.call(inputs, rnn_states)

        signal_sets, signals_mask = util.gather_present(signals, comm_indices)
        if self._can_see_others:
            signal_sets = util.add_visible_agents_to_each_timestep(
                signal_sets, comm_directions, comm_distances
            )

        # broadcast the agent's own environment features to every neighbour slot
        n_comm_steps = tf.shape(signal_sets)[1]
        env_features_broadcasted = tf.tile(
            tf.expand_dims(env_features, 1), [1, n_comm_steps, 1]
        )

        full_comm_inputs = tf.concat([signal_sets, env_features_broadcasted], axis=2)

        if self._self_signal_decoder is not None:
            comm_init_raw = tf.concat([env_features, signals], axis=1)
            comm_init = self._self_signal_decoder(comm_init_raw)
            comm_init = util.decode_embedding(comm_init, self.comm_rnn.cell.state_size)
            _constants = ()
        else:
            comm_init = None
            _constants = None
        out_features = self.comm_rnn.call(
            full_comm_inputs,
            signals_mask,
            initial_state=comm_init,
            constants=_constants,
        )
        out_features = self._final_layer(
            tf.concat([env_features, out_features], axis=1)
        )

        return out_features, (rnn_states_after, out_features)
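
Example #2 (and Example #4 below) feeds the gathered signal sequences through self.comm_rnn, passing the presence mask and, when a self-signal decoder is configured, an initial state decoded from the agent's own previous signal. The construction of comm_rnn is not shown; one plausible setup (cell type and width are assumptions, not the repository's actual configuration) is a masked Keras RNN that reads over the neighbour axis:

import tensorflow as tf

# Hypothetical communication RNN: consumes [n_agents, max_others, feature_size]
# sequences and emits one fixed-size summary per agent.
comm_rnn = tf.keras.layers.RNN(tf.keras.layers.GRUCell(32))

signal_seqs = tf.random.normal([4, 3, 16])        # [n_agents, max_others, signal_size]
mask = tf.constant([[True, True, False],
                    [True, False, False],
                    [True, True, True],
                    [True, False, False]])
pooled = comm_rnn(signal_seqs, mask=mask)          # [n_agents, 32]

With return_sequences left at its default of False, the layer yields a single vector per agent, which is the shape the tf.concat([env_features, out_features], axis=1) call above expects.
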
Example #3
    def call(self, inputs, states=None, **kwargs):
        assert states is not None
        inputs, present_indices = inputs
        rnn_states, signals = nest.pack_sequence_as(self._state_structure, states)

        # unpack who we are communicating with, and from what distance and direction
        vis_dists = tf.cast(present_indices[:, 1::3], tf.float32)
        vis_dirs = present_indices[:, 2::3]
        present_indices = present_indices[:, 0::3]
        vis_dirs_onehot = tf.cast(tf.one_hot(vis_dirs, 4), tf.float32)

        n_agents, max_others = tf_utils.get_shape_static_or_dynamic(present_indices)

        # present_indices = tf.concat([
        #     tf.expand_dims(tf.range(n_agents)),
        #     present_indices
        # ])

        signals_sets, presence_mask = util.gather_present(signals, present_indices)

        signal_inputs = tf.concat(
            [signals_sets, tf.tile(tf.expand_dims(signals, 1), [1, max_others, 1])],
            axis=2,
        )

        presence_mask_float = tf.cast(presence_mask, tf.float32)
        n_others = tf.maximum(tf.reduce_sum(presence_mask_float), 1.0)
        # stage 1: score each (neighbour signal, own signal) pair, then max-pool the
        # scores over the neighbour axis into a normalised context vector
        att_st1 = self._attention_st1(signal_inputs)
        att_ctx = tf.reduce_max(att_st1, axis=1, keepdims=True) / n_others
        # stage 2: re-score each pair with the pooled context attached to obtain attention weights
        att_st2_input = tf.concat(
            [signal_inputs, tf.tile(att_ctx, [1, max_others, 1])], axis=2
        )
        attention = self._attention_st2(att_st2_input)

        comm_features = tf.reduce_sum(attention * signals_sets, axis=1) / n_others

        full_input = tf.concat([inputs, comm_features], axis=1)

        output, rnn_states_after = self.rnn_cell.call(full_input, rnn_states)
        new_own_signal = output[:, :self.signal_size]
        states_after = rnn_states_after, new_own_signal
        full_output = tf.concat([new_own_signal, output], axis=1)
        return full_output, tuple(nest.flatten(states_after))
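
The present_indices tensor in Example #3 packs three numbers per visible-neighbour slot: the neighbour's agent index, the distance, and the direction, which is why the code strides through it with [:, 0::3], [:, 1::3] and [:, 2::3]. A small illustrative tensor (values invented, layout inferred from that striding, with -1 marking an empty slot) unpacks like this:

import tensorflow as tf

present_indices = tf.constant([
    [1, 3, 2, -1, 0, 0],   # agent 0 sees agent 1 at distance 3, direction 2; second slot empty
    [0, 3, 0,  2, 5, 1],   # agent 1 sees agent 0 and agent 2
])
who   = present_indices[:, 0::3]    # [[1, -1], [0, 2]]
dists = present_indices[:, 1::3]    # [[3,  0], [3, 5]]
dirs  = present_indices[:, 2::3]    # [[2,  0], [0, 1]]
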
Example #4
    def _call(self, inputs, states, comm_indices, comm_directions, comm_distances):
        if self.version == 1:
            # version 1 reads the previous signals out of the nested RNN state
            signals = states[-1][0]
        else:
            # version 2 carries (rnn_states, signals) as an explicit pair
            states, signals = states

        signal_seqs, mask = util.gather_present(signals, comm_indices)
        if self._can_see_others:
            signal_seqs = util.add_visible_agents_to_each_timestep(
                signal_seqs, comm_directions, comm_distances
            )
        if self._self_signal_decoder is not None:
            comm_init = self._self_signal_decoder(signals)
            comm_init = util.decode_embedding(comm_init, self.comm_rnn.cell.state_size)
            _constants = ()
        else:
            comm_init = None
            _constants = None
        comm_result = self.comm_rnn.call(
            signal_seqs, mask, initial_state=comm_init, constants=_constants
        )

        if self._can_see_others:
            vis_features = util.build_visible_agents_features(
                comm_directions, comm_distances
            )
            vis_features = util.average_by_mask(
                vis_features, tf.greater_equal(comm_indices, 0)
            )
            inputs = tf.concat([vis_features, inputs], axis=1)
        full_input = tf.concat([inputs, comm_result], axis=1)

        features, *states_after = self.rnn_cell.call(full_input, states)
        if self.version == 2:
            states_after = states_after, features
        return features, states_after
Example #5
def _gather_and_average(vectors: "n_agents vector_size",
                        indices: "n_agents max_others"):
    vector_seqs, mask = util.gather_present(vectors, indices)
    return _average_by_mask(vector_seqs, mask)
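
Example #5 pairs the same gather step with a masked average. _average_by_mask is not shown; a sketch, assuming it averages only over the slots the mask marks as present (name and exact behaviour are an assumption), could be:

import tensorflow as tf

def average_by_mask_sketch(vector_seqs, mask):
    # vector_seqs: [n_agents, max_others, vector_size]; mask: [n_agents, max_others], bool.
    mask_f = tf.cast(mask, vector_seqs.dtype)[:, :, tf.newaxis]
    totals = tf.reduce_sum(vector_seqs * mask_f, axis=1)      # sum over present neighbours
    counts = tf.maximum(tf.reduce_sum(mask_f, axis=1), 1.0)   # avoid division by zero
    return totals / counts                                    # [n_agents, vector_size]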