def _call(self, inputs, states, comm_indices, comm_directions, comm_distances):
    """Run one step mixing each agent's observation with neighbors' signals.

    Args:
        inputs: per-agent observation features, shape [n_agents, obs_size].
        states: (rnn_states, signals) when self.signal_size > 0, otherwise
            just rnn_states.
        comm_indices: [n_agents, max_others] integer indices of visible
            agents; absent slots are presumably marked by negative indices
            (handled inside util.gather_present) — TODO confirm.
        comm_directions, comm_distances: per-neighbor geometry features.

    Returns:
        (features, out_states) where out_states mirrors the structure of
        `states` (a new own signal is appended when signalling is enabled).
    """
    n_agents, max_others = util.get_shape_static_or_dynamic(comm_indices)
    if self.signal_size > 0:
        rnn_states, signals = states
    else:
        rnn_states = states
        # No signalling configured: substitute an empty per-agent signal so
        # the gather/pad/concat path below stays uniform.
        signals = tf.zeros([n_agents, 0])
    signal_sets, present_mask = util.gather_present(signals, comm_indices)
    if self.can_see_others:
        signal_sets = util.add_visible_agents_to_each_timestep(
            signal_sets, comm_directions, comm_distances)
    # Clip the neighbor axis to at most max_agents, then pad it back up to
    # exactly max_agents so the flattened matrix has a fixed size.
    signal_sets = signal_sets[:, :self.max_agents, :]
    # BUG FIX: the original padded by (max_agents - max_others), which goes
    # negative when more than max_agents neighbors are visible (the slice
    # above has already clipped them) and makes tf.pad raise. Clamp at zero;
    # max_others may be a Python int (static shape) or a scalar tensor.
    pad_len = self.max_agents - max_others
    if isinstance(pad_len, int):
        pad_len = max(pad_len, 0)
    else:
        pad_len = tf.maximum(pad_len, 0)
    signal_sets = tf.pad(
        signal_sets,
        [
            (0, 0),
            (0, pad_len),
            (0, 0),
        ],
    )
    signals_flat = tf.reshape(signal_sets, [n_agents, self.signal_matrix_size])
    full_inputs = tf.concat([inputs, signals_flat], axis=1)
    features, new_rnn_states = self.rnn_cell.call(full_inputs, rnn_states)
    if self.signal_size > 0:
        # Derive the agent's new outgoing signal from the fresh features.
        new_own_signal = self.signal_generator.call(features)
        out_states = new_rnn_states, new_own_signal
    else:
        out_states = new_rnn_states
    return features, out_states
def _call(self, inputs, states, comm_indices, comm_directions, comm_distances):
    """One step: encode own observation, then attend over neighbor signals
    with a communication RNN and fuse both into the output features.

    The returned state pairs the updated recurrent state with the output
    features, which double as the agent's outgoing signal for the next step.
    """
    rnn_states, own_signals = states

    # Optionally prepend visibility features (averaged over present
    # neighbors) to the raw observation before the environment RNN.
    if self._can_see_others:
        visibility = util.build_visible_agents_features(
            comm_directions, comm_distances
        )
        visibility = util.average_by_mask(
            visibility, tf.greater_equal(comm_indices, 0)
        )
        inputs = tf.concat([visibility, inputs], axis=1)

    env_repr, next_rnn_states = self.rnn_cell.call(inputs, rnn_states)

    neighbor_signals, neighbor_mask = util.gather_present(own_signals, comm_indices)
    if self._can_see_others:
        neighbor_signals = util.add_visible_agents_to_each_timestep(
            neighbor_signals, comm_directions, comm_distances
        )

    # Broadcast the per-agent environment representation along the
    # communication-step axis and pair it with every neighbor signal.
    n_steps = tf.shape(neighbor_signals)[1]
    tiled_env = tf.tile(tf.expand_dims(env_repr, 1), [1, n_steps, 1])
    comm_inputs = tf.concat([neighbor_signals, tiled_env], axis=2)

    # Seed the communication RNN from a decoding of (env features, own
    # signal) when a decoder is configured; otherwise use its default
    # initial state.
    if self._self_signal_decoder is None:
        init_state = None
        rnn_constants = None
    else:
        decoded = self._self_signal_decoder(
            tf.concat([env_repr, own_signals], axis=1)
        )
        init_state = util.decode_embedding(decoded, self.comm_rnn.cell.state_size)
        rnn_constants = ()

    comm_out = self.comm_rnn.call(
        comm_inputs,
        neighbor_mask,
        initial_state=init_state,
        constants=rnn_constants,
    )
    result = self._final_layer(tf.concat([env_repr, comm_out], axis=1))
    return result, (next_rnn_states, result)
def call(self, inputs, states=None, **kwargs):
    """One step of a comm cell that attends over visible agents' signals.

    `inputs` is a pair (observation features, packed neighbor info); the
    packed tensor interleaves three columns per visible neighbor:
    [index, distance, direction, index, distance, direction, ...].
    Returns (full_output, flat states), where the first `signal_size`
    features of the RNN output are re-emitted as the agent's new signal.
    """
    assert states is not None
    inputs, present_indices = inputs
    rnn_states, signals = nest.pack_sequence_as(self._state_structure, states)
    # Unpack who we are communicating with and from where: every neighbor
    # occupies three consecutive columns of present_indices.
    vis_dists = tf.cast(present_indices[:, 1::3], tf.float32)
    vis_dirs = present_indices[:, 2::3]
    present_indices = present_indices[:, 0::3]
    # NOTE(review): vis_dists and vis_dirs_onehot are computed but never
    # used below — confirm whether they were meant to be concatenated into
    # signal_inputs.
    vis_dirs_onehot = tf.cast(tf.one_hot(vis_dirs, 4), tf.float32)
    n_agents, max_others = tf_utils.get_shape_static_or_dynamic(present_indices)
    signals_sets, presence_mask = util.gather_present(signals, present_indices)
    # Pair each gathered neighbor signal with this agent's own signal.
    signal_inputs = tf.concat(
        [signals_sets, tf.tile(tf.expand_dims(signals, 1), [1, max_others, 1])],
        axis=2,
    )
    presence_mask_float = tf.cast(presence_mask, tf.float32)
    # NOTE(review): reduce_sum with no axis yields the GLOBAL number of
    # visible (agent, neighbor) pairs as a scalar, yet it is used below as a
    # per-agent normalizer; a per-agent count (axis=1, with broadcasting)
    # looks more plausible — confirm intent before changing.
    n_others = tf.maximum(tf.reduce_sum(presence_mask_float), 1.0)
    # Two-stage attention: stage 1 scores each pair, a max-pooled context is
    # appended, then stage 2 produces the final per-neighbor weights.
    att_st1 = self._attention_st1(signal_inputs)
    att_ctx = tf.reduce_max(att_st1, axis=1, keepdims=True) / n_others
    att_st2_input = tf.concat(
        [signal_inputs, tf.tile(att_ctx, [1, max_others, 1])], axis=2
    )
    attention = self._attention_st2(att_st2_input)
    # Weighted, normalized pooling of the neighbors' signals.
    comm_features = tf.reduce_sum(attention * signals_sets, axis=1) / n_others
    full_input = tf.concat([inputs, comm_features], axis=1)
    output, rnn_states_after = self.rnn_cell.call(full_input, rnn_states)
    # The first signal_size output features double as the new outgoing signal.
    new_own_signal = output[:, :self.signal_size]
    states_after = rnn_states_after, new_own_signal
    full_output = tf.concat([new_own_signal, output], axis=1)
    return full_output, tuple(nest.flatten(states_after))
def _call(self, inputs, states, comm_indices, comm_directions, comm_distances):
    """One step: run a communication RNN over neighbors' signals, then feed
    the pooled result (plus optional visibility features) into the main cell.

    Supports two state layouts selected by self.version:
      v1: the outgoing signal lives inside the recurrent state itself;
      v2: states is an explicit (rnn_states, signals) pair.
    """
    if self.version == 1:
        # v1: peek at the own signal stored in the last state element
        # without unpacking — assumes states[-1] is (signal, ...);
        # TODO confirm the cell's state layout.
        signals = states[-1][0]
    else:
        states, signals = states
    signal_seqs, mask = util.gather_present(signals, comm_indices)
    if self._can_see_others:
        signal_seqs = util.add_visible_agents_to_each_timestep(
            signal_seqs, comm_directions, comm_distances
        )
    # Optionally seed the communication RNN from a decoding of the agent's
    # own signal; constants=() is presumably required by the comm RNN's call
    # signature when an initial_state is supplied — verify against comm_rnn.
    if self._self_signal_decoder is not None:
        comm_init = self._self_signal_decoder(signals)
        comm_init = util.decode_embedding(comm_init, self.comm_rnn.cell.state_size)
        _constants = ()
    else:
        comm_init = None
        _constants = None
    comm_result = self.comm_rnn.call(
        signal_seqs, mask, initial_state=comm_init, constants=_constants
    )
    # Visibility features are averaged over present neighbors (index >= 0
    # marks presence) and prepended to the observation.
    if self._can_see_others:
        vis_features = util.build_visible_agents_features(
            comm_directions, comm_distances
        )
        vis_features = util.average_by_mask(
            vis_features, tf.greater_equal(comm_indices, 0)
        )
        inputs = tf.concat([vis_features, inputs], axis=1)
    full_input = tf.concat([inputs, comm_result], axis=1)
    # NOTE(review): the star-unpack makes states_after a LIST wrapping
    # whatever rnn_cell.call returns after the output — confirm callers
    # accept this (v2 additionally nests it as (list, features)).
    features, *states_after = self.rnn_cell.call(full_input, states)
    if self.version == 2:
        # v2: the output features double as the agent's new outgoing signal.
        states_after = states_after, features
    return features, states_after
def _gather_and_average(vectors: "n_agents vector_size", indices: "n_agents max_others"):
    """Gather each agent's visible vectors and mean-pool them.

    Rows of `vectors` selected by `indices` are collected per agent, then
    averaged with absent slots masked out.
    """
    gathered, presence = util.gather_present(vectors, indices)
    return _average_by_mask(gathered, presence)