Example #1
 def __init__(self,
              edge_output_size=None,
              node_output_size=None,
              global_output_size=None,
              name="EncodeProcessDecode"):
     super(EncodeProcessDecode, self).__init__(name=name)
     self._embed = snt.Embed(vocab_size=1267, embed_dim=LATENT_SIZE)
     self._embed_edge = snt.Embed(vocab_size=1, embed_dim=LATENT_SIZE)
     self._embed_global = snt.Embed(vocab_size=1, embed_dim=LATENT_SIZE)
     self._encoder = MLPGraphIndependent()
     self._core = MLPGraphNetwork()
     self._decoder = MLPGraphIndependent()
     # Transforms the outputs into the appropriate shapes.
     if edge_output_size is None:
         edge_fn = None
     else:
         edge_fn = lambda: snt.Linear(edge_output_size, name="edge_output")
     if node_output_size is None:
         node_fn = None
     else:
         node_fn = lambda: snt.Linear(node_output_size, name="node_output")
     if global_output_size is None:
         global_fn = None
     else:
         global_fn = lambda: snt.Linear(global_output_size,
                                        name="global_output")
     self._output_transform = modules.GraphIndependent(
         edge_fn, node_fn, global_fn)
Example #2
    def testDefaultVocabSize(self):
        embed_mod = snt.Embed(vocab_size=100,
                              embed_dim=None,
                              name="embed_small")
        self.assertEqual(embed_mod.embed_dim, 19)

        embed_mod = snt.Embed(vocab_size=1000000,
                              embed_dim=None,
                              name="embed_big")
        self.assertEqual(embed_mod.embed_dim, 190)
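The expected values above line up with the usual default-dimension rule of roughly 6 * vocab_size ** 0.25 applied when embed_dim is None; a minimal sketch of that heuristic (the helper name is illustrative, not part of the snt API):

def default_embed_dim(vocab_size):
    # Assumed heuristic: round(6 * vocab_size ** 0.25)
    return int(round(6.0 * vocab_size ** 0.25))

default_embed_dim(100)      # -> 19, as asserted for "embed_small"
default_embed_dim(1000000)  # -> 190, as asserted for "embed_big"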
Example #3
    def testInvalidRegularizationParameters(self):
        regularizer = tf.contrib.layers.l1_regularizer(scale=0.5)
        with self.assertRaisesRegexp(KeyError, "Invalid regularizer keys.*"):
            snt.Embed(vocab_size=self._vocab_size,
                      embed_dim=self._embed_dim,
                      regularizers={"not_embeddings": regularizer})

        err = "Regularizer for 'embeddings' is not a callable function"
        with self.assertRaisesRegexp(TypeError, err):
            snt.Embed(vocab_size=self._vocab_size,
                      embed_dim=self._embed_dim,
                      regularizers={"embeddings": tf.zeros([1, 2, 3])})
Example #4
 def setUp(self):
     super(EmbedTest, self).setUp()
     self._batch_size = 3
     self._vocab_size = 7
     self._embed_dim = 1
     self._embed_mod = snt.Embed(vocab_size=self._vocab_size,
                                 embed_dim=self._embed_dim)
     self._embed_mod_dense = snt.Embed(vocab_size=self._vocab_size,
                                       embed_dim=self._embed_dim,
                                       densify_gradients=True)
     self._ids = np.asarray([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5],
                             [2, 3, 4, 5, 6]])
Example #5
    def _construct_weights(self):
        """
        Constructs the user/item memories and user/item external memory/outputs

        Also adds the embedding lookups
        """
        self.user_memory = snt.Embed(self.config.user_count,
                                     self.config.embed_size,
                                     initializers=self._embedding_initializers,
                                     name='MemoryEmbed')

        self.user_output = snt.Embed(self.config.user_count,
                                     self.config.embed_size,
                                     initializers=self._embedding_initializers,
                                     name='MemoryOutput')

        self.item_memory = snt.Embed(self.config.item_count,
                                     self.config.embed_size,
                                     initializers=self._embedding_initializers,
                                     name="ItemMemory")
        self._mem_layer = VariableLengthMemoryLayer(
            self.config.hops,
            self.config.embed_size,
            tf.nn.relu,
            initializers=self._hops_init,
            regularizers=self._regularizers,
            name='UserMemoryLayer')

        self._output_module = snt.Sequential([
            DenseLayer(self.config.embed_size,
                       True,
                       tf.nn.relu,
                       initializers=self._initializers,
                       regularizers=self._regularizers,
                       name='Layer'),
            snt.Linear(1,
                       False,
                       initializers=self._output_initializers,
                       regularizers=self._regularizers,
                       name='OutputVector'), tf.squeeze
        ])

        # [batch, embedding size]
        self._cur_user = self.user_memory(self.input_users)
        self._cur_user_output = self.user_output(self.input_users)

        # Item memories as a query
        self._cur_item = self.item_memory(self.input_items)
        self._cur_item_negative = self.item_memory(self.input_items_negative)

        # Share Embeddings
        self._cur_item_output = self._cur_item
        self._cur_item_output_negative = self._cur_item_negative
Example #6
    def __init__(self,
                 edge_output_size=None,
                 node_output_size=None,
                 global_output_size=None,
                 name="EncodeProcessDecode",
                 VOCABSZ_SYMPTOM=1000):
        super(EncodeProcessDecode, self).__init__(name=name)

        ## embed layer
        self._embed_mod_symptomNode = snt.Embed(vocab_size=VOCABSZ_SYMPTOM,
                                                embed_dim=LATENT_SIZE,
                                                name='symptomNode')
        # self._embed_mod_ageGender=snt.Embed(
        #   vocab_size=VOCABSZ_AGEGENDER, embed_dim=LATENT_SIZE/2,name='ageGender')
        self._embed_mod_global = snt.Embed(vocab_size=1,
                                           embed_dim=LATENT_SIZE_SMALL,
                                           name='global')
        self._embed_mod_edge = snt.Embed(vocab_size=1,
                                         embed_dim=LATENT_SIZE_SMALL,
                                         name='edge')

        self._encoder = MLPGraphIndependent(name='enc')
        self._core = MLPGraphNetwork(name='core')
        self._decoder = MLPGraphIndependent(name='dec')
        # Transforms the outputs into the appropriate shapes.
        if edge_output_size is None:
            edge_fn = None
        else:
            #edge_fn = lambda: snt.Linear(edge_output_size, name="edge_output")
            edge_fn = snt.Linear(edge_output_size, name="edge_output")
        if node_output_size is None:
            node_fn = None
        else:
            #node_fn = lambda: snt.Linear(node_output_size, name="node_output")
            node_fn = snt.Linear(node_output_size, name="node_output")
        if global_output_size is None:
            global_fn = None
        else:
            #global_fn = lambda: snt.Linear(global_output_size, name="global_output")
            global_fn = snt.Linear(global_output_size, name="global_output")
        #with self._enter_variable_scope():
        with tf.compat.v1.variable_scope("out_transform"):
            #if 2>1:
            self._output_transform = modules.GraphIndependent(
                edge_fn,
                node_fn,
                global_fn,
                #name='out_transform'
            )
Example #7
    def _build(batch):
        """Build the loss sonnet module."""
        # TODO(lmetz) these are dense updates.... so keeping this small for now.
        tokens = tf.minimum(batch["text"],
                            tf.to_int64(tf.reshape(vocab_size - 1, [1, 1])))
        embed = snt.Embed(vocab_size=vocab_size, embed_dim=embed_dim)
        embedded_tokens = embed(tokens)
        rnn = core_fn()
        bs = tokens.shape.as_list()[0]

        state = rnn.initial_state(bs, trainable=True)

        outputs, state = tf.nn.dynamic_rnn(rnn,
                                           embedded_tokens,
                                           initial_state=state)
        if aggregate_method == "last":
            last_output = outputs[:, -1]  # grab the last output
        elif aggregate_method == "avg":
            last_output = tf.reduce_mean(outputs, [1])  # average over length
        else:
            raise ValueError("not supported aggregate_method")

        logits = snt.Linear(2)(last_output)

        loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=batch["label_onehot"], logits=logits)

        return tf.reduce_mean(loss_vec)
Example #8
    def _fn(batch):
        """Compute the loss from the given batch."""
        # Shape is [bs, seq len, features]
        inp = batch["text"]
        mask = batch["mask"]
        embed = snt.Embed(vocab_size=256, embed_dim=embed_dim)
        embedded_chars = embed(inp)

        rnn = core_fn()
        bs = inp.shape.as_list()[0]

        state = rnn.initial_state(bs, trainable=True)

        outputs, state = tf.nn.dynamic_rnn(rnn,
                                           embedded_chars,
                                           initial_state=state)

        pred_logits = snt.BatchApply(snt.Linear(256))(outputs[:, :-1])
        actual_tokens = inp[:, 1:]

        flat_s = [
            pred_logits.shape[0] * pred_logits.shape[1], pred_logits.shape[2]
        ]
        f_pred_logits = tf.reshape(pred_logits, flat_s)
        f_actual_tokens = tf.reshape(actual_tokens, [flat_s[0]])
        f_mask = tf.reshape(mask[:, 1:], [flat_s[0]])

        loss_vec = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=f_actual_tokens, logits=f_pred_logits)
        total_loss = tf.reduce_sum(f_mask * loss_vec)
        mean_loss = total_loss / tf.reduce_sum(f_mask)

        return mean_loss
Example #9
File: embedding.py Project: lolski/kglib
def common_embedding(features, num_types, type_embedding_dim):
    preexistance_feat = tf.expand_dims(tf.cast(features[:, 0],
                                               dtype=tf.float32),
                                       axis=1)
    type_embedder = snt.Embed(num_types, type_embedding_dim)
    type_embedding = type_embedder(tf.cast(features[:, 1], tf.int32))
    return tf.concat([preexistance_feat, type_embedding], axis=1)
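A minimal call sketch under the feature layout implied by the slicing above (column 0 is a pre-existence flag, column 1 is a type index); the tensor values are hypothetical:

features = tf.constant([[1.0, 0.0],
                        [0.0, 3.0]])  # hypothetical rows: [pre-existence flag, type index]
embedded = common_embedding(features, num_types=5, type_embedding_dim=4)
# embedded has shape [2, 1 + 4]: the flag concatenated with the type embedding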
Example #10
def compute_embedding(objective, config):

    objective = tf.cast(objective, dtype=tf.int32)

    # if objective is rank 3 -> one-hot encoded (batch_dim, sequence, vocab_size)
    # else objective is rank 2 -> need to compute embedding (batch_dim, sequence_of_embedding)
    if config["embedding_size"] > 0:
        embedded_obj = snt.Embed(
            vocab_size=config["max_size_vocab"] + 1,
            embed_dim=config["embedding_size"],
            densify_gradients=True,
        )(objective)

    else:  # embedding_size == 0 -> onehot vector directly to lstm

        # vocab_size, or the number of objectives for this env if you don't use language
        depth = config.get("vocab_size", 2)

        embedded_obj = tf.one_hot(indices=objective,
                                  depth=depth,
                                  on_value=1.0,
                                  off_value=0.0,
                                  axis=1)

        if depth == 2:
            embedded_obj = tf.squeeze(embedded_obj, axis=2)

    return embedded_obj
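A minimal usage sketch with a hypothetical config: a positive "embedding_size" routes the ids through snt.Embed as above, otherwise the one-hot branch is taken.

config = {"embedding_size": 8, "max_size_vocab": 100}  # hypothetical values
objective = tf.constant([[1, 5, 7]], dtype=tf.int64)   # [batch, sequence] of ids
embedded_obj = compute_embedding(objective, config)    # shape [1, 3, 8]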
Example #11
    def _instruction(self, instruction):
        # Split string.
        splitted = tf.string_split(instruction)
        dense = tf.sparse_tensor_to_dense(splitted, default_value='')
        length = tf.reduce_sum(tf.to_int32(tf.not_equal(dense, '')), axis=1)

        # To int64 hash buckets. Small risk of having collisions. Alternatively, a
        # vocabulary can be used.
        num_hash_buckets = 1000
        buckets = tf.string_to_hash_bucket_fast(dense, num_hash_buckets)

        # Embed the instruction. Embedding size 20 seems to be enough.
        embedding_size = 20
        embedding = snt.Embed(num_hash_buckets, embedding_size)(buckets)

        # Pad to make sure there is at least one output.
        padding = tf.to_int32(tf.equal(tf.shape(embedding)[1], 0))
        embedding = tf.pad(embedding, [[0, 0], [0, padding], [0, 0]])

        core = tf.contrib.rnn.LSTMBlockCell(64, name='language_lstm')
        output, _ = tf.nn.dynamic_rnn(core,
                                      embedding,
                                      length,
                                      dtype=tf.float32)

        # Return the last valid output of each sequence: reverse along time by length, then take index 0.
        return tf.reverse_sequence(output, length, seq_axis=1)[:, 0]
Example #12
    def __init__(self, vocab_size, embed_dim):
        super(PositionnalEmbedding, self).__init__(name="positional_embedding")

        with self._enter_variable_scope():
            self.embed = snt.Embed(vocab_size=vocab_size,
                                   embed_dim=embed_dim,
                                   name="embedding")
Example #13
  def _build(batch):
    """Builds the sonnet module."""
    # Shape is [batch size, sequence length]
    inp = batch["text"]

    # Clip token ids so they are at most vocab_size - 1.
    inp = tf.minimum(inp,
                     tf.to_int64(tf.reshape(cfg["vocab_size"] - 1, [1, 1])))

    embed = snt.Embed(vocab_size=cfg["vocab_size"], embed_dim=cfg["embed_dim"])
    embedded_chars = embed(inp)

    rnn = utils.get_rnn_core(cfg["core"])
    batch_size = inp.shape.as_list()[0]

    state = rnn.initial_state(batch_size, trainable=cfg["trainable_init"])

    outputs, state = tf.nn.dynamic_rnn(rnn, embedded_chars, initial_state=state)

    w_init = utils.get_initializer(cfg["w_init"])
    pred_logits = snt.BatchApply(
        snt.Linear(cfg["vocab_size"], initializers={"w": w_init}))(
            outputs[:, :-1])
    actual_output_tokens = inp[:, 1:]

    flat_s = [pred_logits.shape[0] * pred_logits.shape[1], pred_logits.shape[2]]
    flat_pred_logits = tf.reshape(pred_logits, flat_s)
    flat_actual_tokens = tf.reshape(actual_output_tokens, [flat_s[0]])

    loss_vec = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=flat_actual_tokens, logits=flat_pred_logits)
    return tf.reduce_mean(loss_vec)
Example #14
 def _make_encoder(self):
     """Constructs an encoding for a single character ID."""
     embed = snt.Embed(vocab_size=self.hparams.vocab_size +
                       self.hparams.oov_buckets,
                       embed_dim=self.hparams.embed_size)
     mlp = codec_mod.MLPObsEncoder(self.hparams)
     return codec_mod.EncoderSequence([embed, mlp], name="obs_encoder")
Example #15
def common_embedding(features, num_types, type_embedding_dim):
    preexistance_feat = tf.expand_dims(tf.cast(features[:, 0], dtype=tf.float32), axis=1)
    type_embedder = snt.Embed(num_types, type_embedding_dim)
    norm = snt.LayerNorm()
    type_embedding = norm(type_embedder(tf.cast(features[:, 1], tf.int32)))
    tf.summary.histogram('type_embedding_histogram', type_embedding)
    return tf.concat([preexistance_feat, type_embedding], axis=1)
Example #16
File: attribute.py Project: hkuich/kglib
 def _build(self, attribute_value):
     int_attribute_value = tf.cast(attribute_value, dtype=tf.int32)
     tf.summary.histogram('cat_attribute_value_histogram',
                          int_attribute_value)
     embedding = snt.Embed(self._num_categories,
                           self._attr_embedding_dim)(int_attribute_value)
     tf.summary.histogram('cat_embedding_histogram', embedding)
     return tf.squeeze(embedding, axis=1)
Example #17
    def testRegularizersInRegularizationLosses(self):
        regularizer = tf.contrib.layers.l1_regularizer(scale=0.5)
        embed = snt.Embed(vocab_size=self._vocab_size,
                          embed_dim=self._embed_dim,
                          regularizers={"embeddings": regularizer})
        embed(tf.convert_to_tensor(self._ids))

        regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        self.assertRegexpMatches(regularizers[0].name, ".*l1_regularizer.*")
Example #18
    def __init__(self,
                 edge_output_size=None,
                 node_output_size=None,
                 global_output_size=None,
                 name="EncodeProcessDecode"):
        super(EncodeProcessDecode, self).__init__(name=name)

        ###### input -> embed [batch,node] -> [batch node embed]
        self._embed_mod_symptomNode = snt.Embed(vocab_size=VOCABSZ_SYMPTOM,
                                                embed_dim=LATENT_SIZE,
                                                name='symptomNode')
        # self._embed_mod_ageGender=snt.Embed(
        #   vocab_size=VOCABSZ_AGEGENDER, embed_dim=LATENT_SIZE/2,name='ageGender')
        self._embed_mod_global = snt.Embed(vocab_size=1,
                                           embed_dim=LATENT_SIZE_SMALL,
                                           name='global')
        self._embed_mod_edge = snt.Embed(vocab_size=1,
                                         embed_dim=LATENT_SIZE_SMALL,
                                         name='edge')

        self._encoder = MLPGraphIndependent()
        self._core = MLPGraphNetwork()
        self._decoder = MLPGraphIndependent()

        #self.step_mlp = get_mlp_model_steps()
        #self.mapToSym = map2sym()

        # Transforms the outputs into the appropriate shapes.
        if edge_output_size is None:
            edge_fn = None
        else:
            edge_fn = lambda: snt.Linear(edge_output_size, name="edge_output")
        if node_output_size is None:
            node_fn = None
        else:
            node_fn = lambda: snt.Linear(node_output_size, name="node_output")
        if global_output_size is None:
            global_fn = None
        else:
            global_fn = lambda: snt.Linear(global_output_size,
                                           name="global_output")
        with self._enter_variable_scope():
            self._output_transform = modules.GraphIndependent(
                edge_fn, node_fn, global_fn)
Example #19
 def _build_embedding_layers(self, var_model, dimension_map, prefix):
     embedding_layers = []
     for var_name in var_model.get_nominal():
         embedding = snt.Embed(
             vocab_size=var_model.get_num_values(var_name) + 1,
             embed_dim=dimension_map[var_name],
             regularizers=self.regularizers,
             # TODO: densify_gradients?
             name=f"{prefix}-{var_name}")
         embedding_layers.append(embedding)
     return embedding_layers
Example #20
 def testInitializers(self):
     # Since all embeddings are initialized to zero, the extracted embeddings
     # should be as well.
     initializers = {"embeddings": tf.zeros_initializer()}
     embed_mod = snt.Embed(vocab_size=self._vocab_size,
                           embed_dim=self._embed_dim,
                           initializers=initializers)
     embeddings = embed_mod(tf.convert_to_tensor(self._ids))
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         embeddings_ = sess.run(embeddings)
         self.assertAllEqual(embeddings_, np.zeros_like(embeddings_))
Example #21
    def _construct_weights(self):
        """
        Constructs the user/item memories and user/item external memory/outputs
        Also adds the embedding lookups
        """
        self.user_memory = snt.Embed(self.config.user_count,
                                     self.config.embed_size,
                                     initializers=self._embedding_initializers,
                                     regularizers=self._embedding_regularizers,
                                     name='MemoryEmbed')

        self.item_memory = snt.Embed(self.config.item_count,
                                     self.config.embed_size,
                                     initializers=self._embedding_initializers,
                                     regularizers=self._embedding_regularizers,
                                     name="ItemMemory")

        # [batch, embedding size]
        self._cur_user = self.user_memory(self.input_users)

        # Item memories as a query
        self._cur_item = self.item_memory(self.input_items)
        self._cur_item_negative = self.item_memory(self.input_items_negative)
Example #22
 def __init__(self, vocab_size, embed_dim, max_word_length=None, initializer=None, trainable=True, name="char_embed"):
     super(CharEmbed, self).__init__(name=name)
     self._vocab_size = vocab_size  # char_vocab.size
     self._embed_dim = embed_dim
     self._max_word_length = max_word_length
     self._initializers = init_dict(initializer, ['embeddings'])
     self._trainable = trainable
     
     with self._enter_variable_scope():  # because we're in __init__ (not _build)
         self._char_embedding = snt.Embed(vocab_size=self._vocab_size,
                                          embed_dim=self._embed_dim,
                                          trainable=self._trainable,
                                          #initializers={'embeddings':tf.constant_initializer(0.)},
                                          name="internal_embed")
Example #23
 def __init__(self, vocab_size=None, embed_dim=None, initial_matrix=None, trainable=True, name="word_embed"):
     super(WordEmbed, self).__init__(name=name)
     self._vocab_size = vocab_size  # word_vocab.size
     self._embed_dim = embed_dim
     self._trainable = trainable
     if initial_matrix is not None:
         self._vocab_size = initial_matrix.shape[0]
         self._embed_dim = initial_matrix.shape[1]
     
     with self._enter_variable_scope():  # because we're in __init__ (not _build)
         self._embedding = snt.Embed(vocab_size=self._vocab_size,
                                     embed_dim=self._embed_dim,
                                     trainable=self._trainable,
                                     name="internal_embed")
Example #24
    def _build(self, sequence, sequence_length, is_train, previous_state=None):
        """ Add to graph.

        :param sequence:
        :param sequence_length:
        :param support:
        :param channel_state:
        :param previous_state:
        :param is_train:
        :return:
        """
        batch_size = tf.shape(sequence)[0]

        # Convert sequence from token IDs to embeddings.
        if self.token_embedder is None:
            self.token_embedder = snt.Embed(
                vocab_size=self.params["embed.vocab_size"],
                embed_dim=self.params["embed.embed_size"])
        sequence = self.token_embedder(sequence)

        # Get a representation of the read sequence.
        sequence_encoder = tf.nn.rnn_cell.LSTMCell(
            self.params["sequence_encoder.n_units"])
        self.state_size = sum(sequence_encoder.state_size)

        if previous_state is not None:
            init_state = previous_state
            init_state = tf.split(
                init_state,
                len(nest.flatten(sequence_encoder.state_size)),
                axis=1)
            init_state = nest.pack_sequence_as(
                structure=sequence_encoder.state_size,
                flat_sequence=init_state)
        else:
            init_state = sequence_encoder.zero_state(
                batch_size=batch_size,
                dtype=tf.float32)

        _, current_state = tf.nn.dynamic_rnn(
            cell=sequence_encoder,
            inputs=sequence,
            sequence_length=sequence_length,
            initial_state=init_state)

        current_state = nest.flatten(current_state)
        current_state = tf.concat(current_state, axis=1)
        current_state = tf.reshape(current_state, [batch_size, self.state_size])
        return current_state
Example #25
    def _build(batch):
        """Build the sonnet module.

    Args:
      batch: A dictionary with keys "label", "label_onehot", and "text" mapping
        to tensors. The "text" consists of int tokens. These tokens are
        truncated to the length of the vocab before performing an embedding
        lookup.

    Returns:
      Loss of the batch.
    """
        vocab_size = cfg["vocab_size"]
        max_token = cfg["dataset"][1]["max_token"]
        if max_token:
            vocab_size = min(max_token, vocab_size)

        # Clip the max token to be vocab_size-1.
        tokens = tf.minimum(tf.to_int32(batch["text"]),
                            tf.to_int32(tf.reshape(vocab_size - 1, [1, 1])))

        embed = snt.Embed(vocab_size=vocab_size, embed_dim=cfg["embed_dim"])
        embedded_tokens = embed(tokens)
        rnn = utils.get_rnn_core(cfg["core"])

        batch_size = tokens.shape.as_list()[0]

        state = rnn.initial_state(batch_size, trainable=cfg["trainable_init"])

        outputs, _ = tf.nn.dynamic_rnn(rnn,
                                       embedded_tokens,
                                       initial_state=state)
        if cfg["loss_compute"] == "last":
            rnn_output = outputs[:, -1]  # grab the last output
        elif cfg["loss_compute"] == "avg":
            rnn_output = tf.reduce_mean(outputs, 1)  # average over length
        elif cfg["loss_compute"] == "max":
            rnn_output = tf.reduce_max(outputs, 1)
        else:
            raise ValueError("Not supported loss_compute [%s]" %
                             cfg["loss_compute"])

        logits = snt.Linear(batch["label_onehot"].shape[1],
                            initializers=init)(rnn_output)

        loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=batch["label_onehot"], logits=logits)
        return tf.reduce_mean(loss_vec)
Example #26
    def testExistingVocab(self):
        # Check that the module can be initialised with an existing vocabulary.
        existing = np.array([[1, 0, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1]],
                            dtype=np.int32)
        expected = np.array([[1, 0, 0, 0], [0, 1, 0, 1], [1, 0, 1, 0]],
                            dtype=np.int32)
        true_vocab_size, true_embed_dim = existing.shape

        inputs = tf.constant(np.array([0, 2, 1]), dtype=tf.int32)
        embed_mod = snt.Embed(existing_vocab=existing)
        embeddings = embed_mod(inputs)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            embeddings_ = sess.run(embeddings)
            self.assertAllClose(embeddings_, expected)
            self.assertEqual(embed_mod.vocab_size, true_vocab_size)
            self.assertEqual(embed_mod.embed_dim, true_embed_dim)
Example #27
def build_modules(is_training, vocab_size):
    """Construct the modules used in the graph."""

    # Construct the custom getter which implements Bayes by Backprop.
    if is_training:
        estimator_mode = tf.constant(bbb.EstimatorModes.sample)
    else:
        estimator_mode = tf.constant(bbb.EstimatorModes.mean)
    lstm_bbb_custom_getter = bbb.bayes_by_backprop_getter(
        posterior_builder=lstm_posterior_builder,
        prior_builder=custom_scale_mixture_prior_builder,
        kl_builder=bbb.stochastic_kl_builder,
        sampling_mode_tensor=estimator_mode,
    )
    non_lstm_bbb_custom_getter = bbb.bayes_by_backprop_getter(
        posterior_builder=non_lstm_posterior_builder,
        prior_builder=custom_scale_mixture_prior_builder,
        kl_builder=bbb.stochastic_kl_builder,
        sampling_mode_tensor=estimator_mode,
    )

    embed_layer = snt.Embed(
        vocab_size=vocab_size,
        embed_dim=FLAGS.embedding_size,
        custom_getter=non_lstm_bbb_custom_getter,
        name="input_embedding",
    )

    cores = [
        snt.LSTM(
            FLAGS.hidden_size,
            custom_getter=lstm_bbb_custom_getter,
            forget_bias=0.0,
            name="lstm_layer_{}".format(i),
        ) for i in six.moves.range(FLAGS.n_layers)
    ]
    rnn_core = snt.DeepRNN(cores,
                           skip_connections=False,
                           name="deep_lstm_core")

    # Do BBB on weights but not biases of output layer.
    output_linear = snt.Linear(vocab_size,
                               custom_getter={"w": non_lstm_bbb_custom_getter})
    return embed_layer, rnn_core, output_linear
Example #28
    def testPartitioners(self):
        # Partition embeddings such that there's one variable per vocabulary entry.
        partitioners = {
            "embeddings":
            tf.variable_axis_size_partitioner(4 * self._embed_dim)
        }
        embed_mod = snt.Embed(vocab_size=self._vocab_size,
                              embed_dim=self._embed_dim,
                              partitioners=partitioners)
        embeddings = embed_mod(tf.convert_to_tensor(self._ids))
        self.assertEqual(type(embed_mod.embeddings),
                         variables.PartitionedVariable)
        self.assertEqual(len(embed_mod.embeddings), self._vocab_size)

        # Ensure that tf.nn.embedding_lookup() plays nicely with embedding
        # variables.
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(embeddings)
Example #29
    def testComputation(self):
        # Initialize each embedding to its index. Thus, the lookup ids are the same
        # as the embeddings themselves.
        initializers = {
            "embeddings":
            tf.constant_initializer([[0], [1], [2], [3], [4], [5], [6]],
                                    dtype=tf.float32)
        }
        embed_mod = snt.Embed(vocab_size=self._vocab_size,
                              embed_dim=self._embed_dim,
                              initializers=initializers)
        embeddings = embed_mod(tf.convert_to_tensor(self._ids))

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            embeddings_ = sess.run(embeddings)
            expected_embeddings = np.reshape(self._ids,
                                             newshape=list(self._ids.shape) +
                                             [self._embed_dim])
            self.assertAllClose(embeddings_, expected_embeddings)
Example #30
    def _build(self, input_graph):
        embed = snt.Embed(self.num_nodes, self.emb_size)
        linear = snt.Linear(self.hidden)
        norm = snt.BatchNorm()
        mlp = snt.nets.MLP([self.hidden, self.hidden], activate_final=False)
        gnn_attentions = []
        for _ in range(self.num_iters):
            gnn_attention = GraphAttention(self.hidden, self.num_heads)
            gnn_attentions.append(gnn_attention)
        final = snt.Linear(self.num_classes)

        node_features = input_graph.nodes
        if node_features[0] is not None:
            embs = tf.constant(node_features, dtype=tf.float32)
            embs = tf.nn.relu(linear(embs))
            embs = norm(mlp(embs) + embs, is_training=True)
        else:
            embs = embed(self.ids)

        for i in range(self.num_iters):
            embs = gnn_attentions[i](embs, input_graph)

        logits = final(embs)
        self.outputs = logits