def _split_mu_log_sigma(self, params, slice_dim):
    """Splits `params` into `mu` and `log_sigma` halves along `slice_dim`.

    Args:
        params: A tensor (or tensor-like) whose static size along
            `slice_dim` is known and even.
        slice_dim: Integer dimension index along which to split.

    Returns:
        A `(mu, log_sigma)` pair of tensors: `mu` is the first half of
        `params` along `slice_dim`, `log_sigma` the second half.

    Raises:
        ValueError: If the size of `params` along `slice_dim` is
            statically unknown or not even.
    """
    params = tf.convert_to_tensor(params)
    size = params.get_shape()[slice_dim].value
    # `.value` is None for a statically unknown dimension; fail with a
    # clear error instead of letting `size % 2` raise a TypeError.
    if size is None:
        raise ValueError(
            '`params` must have a statically known size along dimension {}.'
            .format(slice_dim))
    if size % 2 != 0:
        raise ValueError('`params` must have an even size along dimension {}.'
                         .format(slice_dim))
    half_size = size // 2
    mu = snt.SliceByDim(
        dims=[slice_dim], begin=[0], size=[half_size], name='mu')(params)
    log_sigma = snt.SliceByDim(
        dims=[slice_dim], begin=[half_size], size=[half_size],
        name='log_sigma')(params)
    return mu, log_sigma
def _get_global_encoder(self):
    """Builds an encoder closure that embeds global game variables.

    Creates one `snt.SliceByDim` module per nominal global variable
    (each extracting column `i` of the input) and returns a closure
    that embeds every sliced index column and concatenates the results.

    Returns:
        A callable `global_encoder(global_input)` mapping a batch of
        integer-coded global variables (one column per variable) to a
        single concatenated embedding tensor.
    """
    global_vars = self.model.game.global_vars.get_nominal()
    # One slice module per variable: extracts column i from the input.
    indices = [
        snt.SliceByDim(dims=[1], begin=[i], size=[1], name=f"indices-{var}")
        for i, var in enumerate(global_vars)
    ]

    def global_encoder(global_input):
        state_layers = []
        for j, var in enumerate(global_vars):
            embedding_layer = self.global_embedding_layers[j]
            var_indices = indices[j](global_input)
            # The slice keeps a singleton column axis; squeeze it away so
            # each per-variable embedding is [batch, embed_dim].
            var_state = tf.squeeze(
                embedding_layer(tf.to_int32(var_indices)),
                axis=[1],
                name=f"squeeze_embedding-{var}")
            state_layers.append(var_state)
            tf.summary.histogram("indices-" + embedding_layer.module_name,
                                 var_indices)
            self._summarize_embedding(embedding_layer)
        return tf.concat(state_layers, 1, name="embed_globals")

    return global_encoder
def _get_node_encoder(self):
    """Builds an encoder closure that embeds node variables.

    Nodes carry two kinds of variables — environment objects and legal
    actions — laid out as consecutive column groups in the node input
    (object columns first, then action columns). Each column gets its
    own `snt.SliceByDim` extractor and embedding layer; the encoder
    concatenates all embeddings.

    Returns:
        A callable `node_encoder(node_input)` mapping a batch of
        integer-coded node variables to a single concatenated
        embedding tensor.
    """
    object_vars = self.model.game.object_vars.get_nominal()
    action_vars = self.model.game.action_vars.get_nominal()

    def _make_slices(var_names, offset):
        # One slice module per variable; `offset` accounts for the
        # columns occupied by any preceding variable group.
        return [
            snt.SliceByDim(dims=[1], begin=[offset + i], size=[1],
                           name=f"indices-{var}")
            for i, var in enumerate(var_names)
        ]

    object_indices = _make_slices(object_vars, 0)
    action_indices = _make_slices(action_vars, len(object_vars))

    def node_encoder(node_input):
        state_layers = []
        # Object variables first, then action variables — same order as
        # the column layout and the resulting concat.
        groups = [
            (object_vars, object_indices, self.object_embedding_layers),
            (action_vars, action_indices, self.action_embedding_layers),
        ]
        for var_names, slices, embedding_layers in groups:
            for j, var in enumerate(var_names):
                embedding_layer = embedding_layers[j]
                var_indices = slices[j](node_input)
                # Drop the singleton column axis left by the slice.
                var_state = tf.squeeze(
                    embedding_layer(tf.to_int32(var_indices)),
                    axis=[1],
                    name=f"squeeze_embedding-{var}")
                state_layers.append(var_state)
                tf.summary.histogram(
                    "indices-" + embedding_layer.module_name, var_indices)
                self._summarize_embedding(embedding_layer)
        return tf.concat(state_layers, 1, name="embed_nodes")

    return node_encoder