Example #1
0
 def _split_mu_log_sigma(self, params, slice_dim):
   """Splits `params` into `mu` and `log_sigma` along `slice_dim`.

   Args:
     params: Tensor, or a value convertible to one, to be split.
     slice_dim: Integer index of the dimension to split in half.

   Returns:
     A `(mu, log_sigma)` tuple holding the first and second halves of
     `params` along `slice_dim`, respectively.

   Raises:
     ValueError: If the static size of `params` along `slice_dim` is unknown
       or is not even.
   """
   params = tf.convert_to_tensor(params)
   size = params.get_shape()[slice_dim].value
   # An unknown static dimension reports `None`; fail with a clear message
   # instead of the opaque TypeError that `None % 2` would raise.
   if size is None:
     raise ValueError('`params` must have a statically known size along '
                      'dimension {}.'.format(slice_dim))
   if size % 2 != 0:
     raise ValueError('`params` must have an even size along dimension {}.'
                      .format(slice_dim))
   half_size = size // 2
   mu = snt.SliceByDim(
       dims=[slice_dim], begin=[0], size=[half_size], name='mu')(params)
   log_sigma = snt.SliceByDim(
       dims=[slice_dim], begin=[half_size], size=[half_size],
       name='log_sigma')(params)
   return mu, log_sigma
    def _get_global_encoder(self):
        """Build a function that embeds the game's nominal global variables.

        One `snt.SliceByDim` module is created per nominal global variable;
        the module for variable `i` extracts index `i` (size 1) along
        dimension 1 of its input.

        Returns:
            A callable `global_encoder(global_input)`. For each variable `j`
            it slices `global_input`, casts the slice to int32, feeds it to
            `self.global_embedding_layers[j]`, squeezes axis 1 of the result,
            and finally concatenates all per-variable results along axis 1.
            As a side effect every call records a histogram summary of the
            sliced indices and calls `self._summarize_embedding` per layer.
        """
        global_vars = self.model.game.global_vars.get_nominal()
        slicers = [
            snt.SliceByDim(dims=[1],
                           begin=[i],
                           size=[1],
                           name=f"indices-{var_name}")
            for i, var_name in enumerate(global_vars)
        ]

        def global_encoder(global_input):
            # NOTE(review): presumably `global_input` is rank 2 with one
            # column per nominal variable -- confirm against the caller.
            state_layers = []
            for j, (var_name, slicer) in enumerate(zip(global_vars, slicers)):
                # Index the layer list directly so that a missing layer
                # raises IndexError rather than being silently skipped.
                embedding_layer = self.global_embedding_layers[j]
                var_indices = slicer(global_input)
                # `tf.cast` replaces the deprecated `tf.to_int32`.
                var_state = tf.squeeze(
                    embedding_layer(tf.cast(var_indices, tf.int32)),
                    axis=[1],
                    name=f"squeeze_embedding-{var_name}")
                state_layers.append(var_state)
                tf.summary.histogram("indices-" + embedding_layer.module_name,
                                     var_indices)
                self._summarize_embedding(embedding_layer)
            return tf.concat(state_layers, 1, name="embed_globals")

        return global_encoder
    def _get_node_encoder(self):
        """Build a function that embeds the nominal object and action variables.

        One `snt.SliceByDim` module is created per nominal variable. Object
        variables take indices `0 .. len(object_vars) - 1` along dimension 1
        of the input; action variables take the indices that follow them.

        Returns:
            A callable `node_encoder(node_input)`. For every object variable
            and then every action variable it slices `node_input`, casts the
            slice to int32, feeds it to the matching layer from
            `self.object_embedding_layers` / `self.action_embedding_layers`,
            and squeezes axis 1 of the result. All results are concatenated
            along axis 1. As a side effect every call records a histogram
            summary of the sliced indices and calls
            `self._summarize_embedding` per layer.
        """
        # Embed both kinds of nodes: environment objects and legal actions
        object_vars = self.model.game.object_vars.get_nominal()
        action_vars = self.model.game.action_vars.get_nominal()

        def make_slicers(var_names, offset):
            # One size-1 slicer along dim 1 per variable, starting at `offset`.
            return [
                snt.SliceByDim(dims=[1],
                               begin=[offset + i],
                               size=[1],
                               name=f"indices-{var_name}")
                for i, var_name in enumerate(var_names)
            ]

        object_indices = make_slicers(object_vars, 0)
        # Action columns start right after the object columns.
        action_indices = make_slicers(action_vars, len(object_vars))

        def embed_columns(node_input, var_names, slicers, embedding_layers):
            # Slice, embed and squeeze one column per variable, in order.
            states = []
            for j, (var_name, slicer) in enumerate(zip(var_names, slicers)):
                # Index the layer list directly so that a missing layer
                # raises IndexError rather than being silently skipped.
                embedding_layer = embedding_layers[j]
                var_indices = slicer(node_input)
                # `tf.cast` replaces the deprecated `tf.to_int32`.
                var_state = tf.squeeze(
                    embedding_layer(tf.cast(var_indices, tf.int32)),
                    axis=[1],
                    name=f"squeeze_embedding-{var_name}")
                states.append(var_state)
                tf.summary.histogram("indices-" + embedding_layer.module_name,
                                     var_indices)
                self._summarize_embedding(embedding_layer)
            return states

        def node_encoder(node_input):
            # NOTE(review): presumably `node_input` is rank 2 with object
            # columns followed by action columns -- confirm against the caller.
            state_layers = embed_columns(node_input, object_vars,
                                         object_indices,
                                         self.object_embedding_layers)
            state_layers += embed_columns(node_input, action_vars,
                                          action_indices,
                                          self.action_embedding_layers)
            return tf.concat(state_layers, 1, name="embed_nodes")

        return node_encoder