def __init__(self,
             edge_output_size=None,
             node_output_size=None,
             global_output_size=None,
             name="EncodeProcessDecode"):
  super(EncodeProcessDecode, self).__init__(name=name)
  self._embed = snt.Embed(vocab_size=1267, embed_dim=LATENT_SIZE)
  self._embed_edge = snt.Embed(vocab_size=1, embed_dim=LATENT_SIZE)
  self._embed_global = snt.Embed(vocab_size=1, embed_dim=LATENT_SIZE)
  self._encoder = MLPGraphIndependent()
  self._core = MLPGraphNetwork()
  self._decoder = MLPGraphIndependent()
  # Transforms the outputs into the appropriate shapes.
  if edge_output_size is None:
    edge_fn = None
  else:
    edge_fn = lambda: snt.Linear(edge_output_size, name="edge_output")
  if node_output_size is None:
    node_fn = None
  else:
    node_fn = lambda: snt.Linear(node_output_size, name="node_output")
  if global_output_size is None:
    global_fn = None
  else:
    global_fn = lambda: snt.Linear(global_output_size, name="global_output")
  self._output_transform = modules.GraphIndependent(
      edge_fn, node_fn, global_fn)
def testDefaultVocabSize(self):
  embed_mod = snt.Embed(vocab_size=100, embed_dim=None, name="embed_small")
  self.assertEqual(embed_mod.embed_dim, 19)

  embed_mod = snt.Embed(vocab_size=1000000, embed_dim=None, name="embed_big")
  self.assertEqual(embed_mod.embed_dim, 190)
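# Hedged reconstruction of the sizing rule the expected values above imply:
# Sonnet v1's default embed_dim scales with the fourth root of the vocabulary
# size, approximately round(6 * vocab_size ** 0.25). A minimal sketch:
def default_embed_dim(vocab_size):
  # Assumed reconstruction of snt.Embed's default embed_dim heuristic.
  return int(round(6.0 * vocab_size ** 0.25))

assert default_embed_dim(100) == 19        # matches embed_small above
assert default_embed_dim(1000000) == 190   # matches embed_big above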
def testInvalidRegularizationParameters(self):
  regularizer = tf.contrib.layers.l1_regularizer(scale=0.5)
  with self.assertRaisesRegexp(KeyError, "Invalid regularizer keys.*"):
    snt.Embed(vocab_size=self._vocab_size,
              embed_dim=self._embed_dim,
              regularizers={"not_embeddings": regularizer})

  err = "Regularizer for 'embeddings' is not a callable function"
  with self.assertRaisesRegexp(TypeError, err):
    snt.Embed(vocab_size=self._vocab_size,
              embed_dim=self._embed_dim,
              regularizers={"embeddings": tf.zeros([1, 2, 3])})
def setUp(self):
  super(EmbedTest, self).setUp()
  self._batch_size = 3
  self._vocab_size = 7
  self._embed_dim = 1
  self._embed_mod = snt.Embed(vocab_size=self._vocab_size,
                              embed_dim=self._embed_dim)
  self._embed_mod_dense = snt.Embed(vocab_size=self._vocab_size,
                                    embed_dim=self._embed_dim,
                                    densify_gradients=True)
  self._ids = np.asarray([[0, 1, 2, 3, 4],
                          [1, 2, 3, 4, 5],
                          [2, 3, 4, 5, 6]])
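# Shape sketch for the fixture above: looking up ids of shape
# [batch_size=3, seq_len=5] in a table with embed_dim=1 yields embeddings of
# shape [3, 5, 1], following standard tf.nn.embedding_lookup semantics.
import numpy as np
import tensorflow as tf
import sonnet as snt

ids = np.asarray([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5], [2, 3, 4, 5, 6]])
embeddings = snt.Embed(vocab_size=7, embed_dim=1)(tf.convert_to_tensor(ids))
print(embeddings.shape)  # (3, 5, 1)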
def _construct_weights(self):
  """Constructs the user/item memories and user/item external memory/outputs.

  Also adds the embedding lookups.
  """
  self.user_memory = snt.Embed(self.config.user_count, self.config.embed_size,
                               initializers=self._embedding_initializers,
                               name='MemoryEmbed')
  self.user_output = snt.Embed(self.config.user_count, self.config.embed_size,
                               initializers=self._embedding_initializers,
                               name='MemoryOutput')
  self.item_memory = snt.Embed(self.config.item_count, self.config.embed_size,
                               initializers=self._embedding_initializers,
                               name="ItemMemory")
  self._mem_layer = VariableLengthMemoryLayer(self.config.hops,
                                              self.config.embed_size,
                                              tf.nn.relu,
                                              initializers=self._hops_init,
                                              regularizers=self._regularizers,
                                              name='UserMemoryLayer')
  self._output_module = snt.Sequential([
      DenseLayer(self.config.embed_size, True, tf.nn.relu,
                 initializers=self._initializers,
                 regularizers=self._regularizers,
                 name='Layer'),
      snt.Linear(1, False,
                 initializers=self._output_initializers,
                 regularizers=self._regularizers,
                 name='OutputVector'),
      tf.squeeze,
  ])

  # [batch, embedding size]
  self._cur_user = self.user_memory(self.input_users)
  self._cur_user_output = self.user_output(self.input_users)

  # Item memories as the query.
  self._cur_item = self.item_memory(self.input_items)
  self._cur_item_negative = self.item_memory(self.input_items_negative)

  # Share embeddings between the item memory and item output.
  self._cur_item_output = self._cur_item
  self._cur_item_output_negative = self._cur_item_negative
def __init__(self,
             edge_output_size=None,
             node_output_size=None,
             global_output_size=None,
             name="EncodeProcessDecode",
             VOCABSZ_SYMPTOM=1000):
  super(EncodeProcessDecode, self).__init__(name=name)
  # Embed layers.
  self._embed_mod_symptomNode = snt.Embed(vocab_size=VOCABSZ_SYMPTOM,
                                          embed_dim=LATENT_SIZE,
                                          name='symptomNode')
  self._embed_mod_global = snt.Embed(vocab_size=1,
                                     embed_dim=LATENT_SIZE_SMALL,
                                     name='global')
  self._embed_mod_edge = snt.Embed(vocab_size=1,
                                   embed_dim=LATENT_SIZE_SMALL,
                                   name='edge')
  self._encoder = MLPGraphIndependent(name='enc')
  self._core = MLPGraphNetwork(name='core')
  self._decoder = MLPGraphIndependent(name='dec')
  # Transforms the outputs into the appropriate shapes. Note that
  # modules.GraphIndependent expects callables that *build* the output
  # modules, so each *_fn must be a lambda rather than a module instance.
  if edge_output_size is None:
    edge_fn = None
  else:
    edge_fn = lambda: snt.Linear(edge_output_size, name="edge_output")
  if node_output_size is None:
    node_fn = None
  else:
    node_fn = lambda: snt.Linear(node_output_size, name="node_output")
  if global_output_size is None:
    global_fn = None
  else:
    global_fn = lambda: snt.Linear(global_output_size, name="global_output")
  with tf.compat.v1.variable_scope("out_transform"):
    self._output_transform = modules.GraphIndependent(
        edge_fn, node_fn, global_fn)
def _build(batch):
  """Build the loss sonnet module."""
  # TODO(lmetz) these are dense updates.... so keeping this small for now.
  tokens = tf.minimum(batch["text"],
                      tf.to_int64(tf.reshape(vocab_size - 1, [1, 1])))
  embed = snt.Embed(vocab_size=vocab_size, embed_dim=embed_dim)
  embedded_tokens = embed(tokens)
  rnn = core_fn()
  bs = tokens.shape.as_list()[0]

  state = rnn.initial_state(bs, trainable=True)
  outputs, state = tf.nn.dynamic_rnn(rnn, embedded_tokens, initial_state=state)
  if aggregate_method == "last":
    last_output = outputs[:, -1]  # grab the last output
  elif aggregate_method == "avg":
    last_output = tf.reduce_mean(outputs, [1])  # average over length
  else:
    raise ValueError("Unsupported aggregate_method [%s]" % aggregate_method)

  logits = snt.Linear(2)(last_output)
  loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=batch["label_onehot"], logits=logits)
  return tf.reduce_mean(loss_vec)
def _fn(batch):
  """Compute the loss from the given batch."""
  # Shape is [bs, seq len, features].
  inp = batch["text"]
  mask = batch["mask"]
  embed = snt.Embed(vocab_size=256, embed_dim=embed_dim)
  embedded_chars = embed(inp)
  rnn = core_fn()
  bs = inp.shape.as_list()[0]
  state = rnn.initial_state(bs, trainable=True)
  outputs, state = tf.nn.dynamic_rnn(rnn, embedded_chars, initial_state=state)
  pred_logits = snt.BatchApply(snt.Linear(256))(outputs[:, :-1])
  actual_tokens = inp[:, 1:]

  flat_s = [pred_logits.shape[0] * pred_logits.shape[1], pred_logits.shape[2]]
  f_pred_logits = tf.reshape(pred_logits, flat_s)
  f_actual_tokens = tf.reshape(actual_tokens, [flat_s[0]])
  f_mask = tf.reshape(mask[:, 1:], [flat_s[0]])

  loss_vec = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=f_actual_tokens, logits=f_pred_logits)
  total_loss = tf.reduce_sum(f_mask * loss_vec)
  mean_loss = total_loss / tf.reduce_sum(f_mask)

  return mean_loss
def common_embedding(features, num_types, type_embedding_dim):
  preexistance_feat = tf.expand_dims(
      tf.cast(features[:, 0], dtype=tf.float32), axis=1)
  type_embedder = snt.Embed(num_types, type_embedding_dim)
  type_embedding = type_embedder(tf.cast(features[:, 1], tf.int32))
  return tf.concat([preexistance_feat, type_embedding], axis=1)
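# Hypothetical usage of common_embedding above (the feature values are made
# up). The layout assumed from the indexing: column 0 is a scalar flag, column
# 1 an integer type id, so the output concatenates a [batch, 1] float with a
# [batch, type_embedding_dim] embedding.
import tensorflow as tf

features = tf.constant([[1, 3], [0, 5]], dtype=tf.int64)  # [batch=2, 2]
combined = common_embedding(features, num_types=10, type_embedding_dim=4)
# combined.shape == (2, 5): 1 flag column + 4 embedding columns.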
def compute_embedding(objective, config):
  objective = tf.cast(objective, dtype=tf.int32)
  # If objective is rank 3, it is one-hot encoded: (batch_dim, sequence,
  # vocab_size). If objective is rank 2, we need to compute an embedding:
  # (batch_dim, sequence_of_embedding).
  if config["embedding_size"] > 0:
    embedded_obj = snt.Embed(
        vocab_size=config["max_size_vocab"] + 1,
        embed_dim=config["embedding_size"],
        densify_gradients=True,
    )(objective)
  else:
    # embedding_size == 0: feed a one-hot vector directly to the LSTM.
    # vocab_size is the number of objectives for this env if you don't use
    # language.
    depth = config.get("vocab_size", 2)
    embedded_obj = tf.one_hot(indices=objective, depth=depth,
                              on_value=1.0, off_value=0.0, axis=1)
    if depth == 2:
      embedded_obj = tf.squeeze(embedded_obj, axis=2)
  return embedded_obj
def _instruction(self, instruction):
  # Split string.
  splitted = tf.string_split(instruction)
  dense = tf.sparse_tensor_to_dense(splitted, default_value='')
  length = tf.reduce_sum(tf.to_int32(tf.not_equal(dense, '')), axis=1)

  # To int64 hash buckets. Small risk of having collisions. Alternatively, a
  # vocabulary can be used.
  num_hash_buckets = 1000
  buckets = tf.string_to_hash_bucket_fast(dense, num_hash_buckets)

  # Embed the instruction. Embedding size 20 seems to be enough.
  embedding_size = 20
  embedding = snt.Embed(num_hash_buckets, embedding_size)(buckets)

  # Pad to make sure there is at least one output.
  padding = tf.to_int32(tf.equal(tf.shape(embedding)[1], 0))
  embedding = tf.pad(embedding, [[0, 0], [0, padding], [0, 0]])

  core = tf.contrib.rnn.LSTMBlockCell(64, name='language_lstm')
  output, _ = tf.nn.dynamic_rnn(core, embedding, length, dtype=tf.float32)

  # Return the last output.
  return tf.reverse_sequence(output, length, seq_axis=1)[:, 0]
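# Minimal sketch of the hash-bucket trick used above (TF1 API). Each word maps
# to one of num_hash_buckets integer ids without building a vocabulary;
# distinct words may collide, which the comment above accepts as a small risk.
import tensorflow as tf

words = tf.constant([["open", "the", "red", "door"]])
bucket_ids = tf.string_to_hash_bucket_fast(words, 1000)  # int64, same shape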
def __init__(self, vocab_size, embed_dim):
  super(PositionnalEmbedding, self).__init__(name="positional_embedding")
  with self._enter_variable_scope():
    self.embed = snt.Embed(vocab_size=vocab_size,
                           embed_dim=embed_dim,
                           name="embedding")
def _build(batch):
  """Builds the sonnet module."""
  # Shape is [batch size, sequence length].
  inp = batch["text"]

  # Clip the vocab to be at most vocab_size.
  inp = tf.minimum(inp,
                   tf.to_int64(tf.reshape(cfg["vocab_size"] - 1, [1, 1])))

  embed = snt.Embed(vocab_size=cfg["vocab_size"], embed_dim=cfg["embed_dim"])
  embedded_chars = embed(inp)
  rnn = utils.get_rnn_core(cfg["core"])
  batch_size = inp.shape.as_list()[0]
  state = rnn.initial_state(batch_size, trainable=cfg["trainable_init"])
  outputs, state = tf.nn.dynamic_rnn(rnn, embedded_chars, initial_state=state)

  w_init = utils.get_initializer(cfg["w_init"])
  pred_logits = snt.BatchApply(
      snt.Linear(cfg["vocab_size"], initializers={"w": w_init}))(
          outputs[:, :-1])
  actual_output_tokens = inp[:, 1:]

  flat_s = [pred_logits.shape[0] * pred_logits.shape[1], pred_logits.shape[2]]
  flat_pred_logits = tf.reshape(pred_logits, flat_s)
  flat_actual_tokens = tf.reshape(actual_output_tokens, [flat_s[0]])

  loss_vec = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=flat_actual_tokens, logits=flat_pred_logits)
  return tf.reduce_mean(loss_vec)
def _make_encoder(self):
  """Constructs an encoding for a single character ID."""
  embed = snt.Embed(
      vocab_size=self.hparams.vocab_size + self.hparams.oov_buckets,
      embed_dim=self.hparams.embed_size)
  mlp = codec_mod.MLPObsEncoder(self.hparams)
  return codec_mod.EncoderSequence([embed, mlp], name="obs_encoder")
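# Design note on the vocab arithmetic above: out-of-vocabulary characters are
# hashed into `oov_buckets` extra ids, so the embedding table needs
# vocab_size + oov_buckets rows to cover every possible id.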
def common_embedding(features, num_types, type_embedding_dim):
  preexistance_feat = tf.expand_dims(
      tf.cast(features[:, 0], dtype=tf.float32), axis=1)
  type_embedder = snt.Embed(num_types, type_embedding_dim)
  norm = snt.LayerNorm()
  type_embedding = norm(type_embedder(tf.cast(features[:, 1], tf.int32)))
  tf.summary.histogram('type_embedding_histogram', type_embedding)
  return tf.concat([preexistance_feat, type_embedding], axis=1)
def _build(self, attribute_value):
  int_attribute_value = tf.cast(attribute_value, dtype=tf.int32)
  tf.summary.histogram('cat_attribute_value_histogram', int_attribute_value)
  embedding = snt.Embed(self._num_categories,
                        self._attr_embedding_dim)(int_attribute_value)
  tf.summary.histogram('cat_embedding_histogram', embedding)
  return tf.squeeze(embedding, axis=1)
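# Shape assumption made explicit: attribute_value is expected to have shape
# [batch, 1], giving an embedding of [batch, 1, attr_embedding_dim]; the final
# tf.squeeze(axis=1) then returns [batch, attr_embedding_dim]. A rank-1 input
# would fail at the squeeze.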
def testRegularizersInRegularizationLosses(self):
  regularizer = tf.contrib.layers.l1_regularizer(scale=0.5)
  embed = snt.Embed(vocab_size=self._vocab_size,
                    embed_dim=self._embed_dim,
                    regularizers={"embeddings": regularizer})
  embed(tf.convert_to_tensor(self._ids))

  regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
  self.assertRegexpMatches(regularizers[0].name, ".*l1_regularizer.*")
def __init__(self,
             edge_output_size=None,
             node_output_size=None,
             global_output_size=None,
             name="EncodeProcessDecode"):
  super(EncodeProcessDecode, self).__init__(name=name)
  # Input -> embed: [batch, node] -> [batch, node, embed].
  self._embed_mod_symptomNode = snt.Embed(vocab_size=VOCABSZ_SYMPTOM,
                                          embed_dim=LATENT_SIZE,
                                          name='symptomNode')
  self._embed_mod_global = snt.Embed(vocab_size=1,
                                     embed_dim=LATENT_SIZE_SMALL,
                                     name='global')
  self._embed_mod_edge = snt.Embed(vocab_size=1,
                                   embed_dim=LATENT_SIZE_SMALL,
                                   name='edge')
  self._encoder = MLPGraphIndependent()
  self._core = MLPGraphNetwork()
  self._decoder = MLPGraphIndependent()
  # Transforms the outputs into the appropriate shapes.
  if edge_output_size is None:
    edge_fn = None
  else:
    edge_fn = lambda: snt.Linear(edge_output_size, name="edge_output")
  if node_output_size is None:
    node_fn = None
  else:
    node_fn = lambda: snt.Linear(node_output_size, name="node_output")
  if global_output_size is None:
    global_fn = None
  else:
    global_fn = lambda: snt.Linear(global_output_size, name="global_output")
  with self._enter_variable_scope():
    self._output_transform = modules.GraphIndependent(
        edge_fn, node_fn, global_fn)
def _build_embedding_layers(self, var_model, dimension_map, prefix):
  embedding_layers = []
  for var_name in var_model.get_nominal():
    embedding = snt.Embed(
        vocab_size=var_model.get_num_values(var_name) + 1,
        embed_dim=dimension_map[var_name],
        regularizers=self.regularizers,
        # TODO: densify_gradients?
        name=f"{prefix}-{var_name}")
    embedding_layers.append(embedding)
  return embedding_layers
def testInitializers(self):
  # Since all embeddings are initialized to zero, the extracted embeddings
  # should be as well.
  initializers = {"embeddings": tf.zeros_initializer()}
  embed_mod = snt.Embed(vocab_size=self._vocab_size,
                        embed_dim=self._embed_dim,
                        initializers=initializers)
  embeddings = embed_mod(tf.convert_to_tensor(self._ids))

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    embeddings_ = sess.run(embeddings)
  self.assertAllEqual(embeddings_, np.zeros_like(embeddings_))
def _construct_weights(self):
  """Constructs the user/item memories and user/item external memory/outputs.

  Also adds the embedding lookups.
  """
  self.user_memory = snt.Embed(self.config.user_count, self.config.embed_size,
                               initializers=self._embedding_initializers,
                               regularizers=self._embedding_regularizers,
                               name='MemoryEmbed')
  self.item_memory = snt.Embed(self.config.item_count, self.config.embed_size,
                               initializers=self._embedding_initializers,
                               regularizers=self._embedding_regularizers,
                               name="ItemMemory")

  # [batch, embedding size]
  self._cur_user = self.user_memory(self.input_users)

  # Item memories as the query.
  self._cur_item = self.item_memory(self.input_items)
  self._cur_item_negative = self.item_memory(self.input_items_negative)
def __init__(self, vocab_size, embed_dim, max_word_length=None,
             initializer=None, trainable=True, name="char_embed"):
  super(CharEmbed, self).__init__(name=name)
  self._vocab_size = vocab_size  # char_vocab.size
  self._embed_dim = embed_dim
  self._max_word_length = max_word_length
  self._initializers = init_dict(initializer, ['embeddings'])
  self._trainable = trainable
  # Create the embedding in __init__ (not _build) so it lives inside this
  # module's variable scope.
  with self._enter_variable_scope():
    self._char_embedding = snt.Embed(vocab_size=self._vocab_size,
                                     embed_dim=self._embed_dim,
                                     trainable=self._trainable,
                                     name="internal_embed")
def __init__(self, vocab_size=None, embed_dim=None, initial_matrix=None,
             trainable=True, name="word_embed"):
  super(WordEmbed, self).__init__(name=name)
  self._vocab_size = vocab_size  # word_vocab.size
  self._embed_dim = embed_dim
  self._trainable = trainable
  # `is not None` rather than truthiness: the truth value of a multi-element
  # ndarray is ambiguous. Note that only the *shape* of initial_matrix is
  # used here; its values are not installed as an initializer.
  if initial_matrix is not None:
    self._vocab_size = initial_matrix.shape[0]
    self._embed_dim = initial_matrix.shape[1]
  # Create the embedding in __init__ (not _build) so it lives inside this
  # module's variable scope.
  with self._enter_variable_scope():
    self._embedding = snt.Embed(vocab_size=self._vocab_size,
                                embed_dim=self._embed_dim,
                                trainable=self._trainable,
                                name="internal_embed")
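# Hypothetical usage of WordEmbed above (the file name is made up): when an
# initial_matrix is supplied, vocab_size and embed_dim are read from its shape
# and can be omitted.
import numpy as np

pretrained = np.load("word_vectors.npy")           # assumed [vocab, dim] array
word_embed = WordEmbed(initial_matrix=pretrained)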
def _build(self, sequence, sequence_length, is_train, previous_state=None):
  """Add to graph.

  :param sequence: batch of token-ID sequences.
  :param sequence_length: length of each sequence in the batch.
  :param is_train: whether the graph is built for training.
  :param previous_state: optional flattened LSTM state to resume from.
  :return: the current (flattened) LSTM state.
  """
  batch_size = tf.shape(sequence)[0]

  # Convert sequence from token IDs to embeddings.
  if self.token_embedder is None:
    self.token_embedder = snt.Embed(
        vocab_size=self.params["embed.vocab_size"],
        embed_dim=self.params["embed.embed_size"])
  sequence = self.token_embedder(sequence)

  # Get a representation of the read sequence.
  sequence_encoder = tf.nn.rnn_cell.LSTMCell(
      self.params["sequence_encoder.n_units"])
  self.state_size = sum(sequence_encoder.state_size)
  if previous_state is not None:
    init_state = previous_state
    init_state = tf.split(
        init_state, len(nest.flatten(sequence_encoder.state_size)), axis=1)
    init_state = nest.pack_sequence_as(
        structure=sequence_encoder.state_size, flat_sequence=init_state)
  else:
    init_state = sequence_encoder.zero_state(
        batch_size=batch_size, dtype=tf.float32)

  _, current_state = tf.nn.dynamic_rnn(
      cell=sequence_encoder,
      inputs=sequence,
      sequence_length=sequence_length,
      initial_state=init_state)
  current_state = nest.flatten(current_state)
  current_state = tf.concat(current_state, axis=1)
  current_state = tf.reshape(current_state, [batch_size, self.state_size])
  return current_state
def _build(batch):
  """Build the sonnet module.

  Args:
    batch: A dictionary with keys "label", "label_onehot", and "text" mapping
      to tensors. The "text" consists of int tokens. These tokens are
      truncated to the length of the vocab before performing an embedding
      lookup.

  Returns:
    Loss of the batch.
  """
  vocab_size = cfg["vocab_size"]
  max_token = cfg["dataset"][1]["max_token"]
  if max_token:
    vocab_size = min(max_token, vocab_size)

  # Clip the max token to be vocab_size - 1.
  tokens = tf.minimum(tf.to_int32(batch["text"]),
                      tf.to_int32(tf.reshape(vocab_size - 1, [1, 1])))
  embed = snt.Embed(vocab_size=vocab_size, embed_dim=cfg["embed_dim"])
  embedded_tokens = embed(tokens)
  rnn = utils.get_rnn_core(cfg["core"])
  batch_size = tokens.shape.as_list()[0]
  state = rnn.initial_state(batch_size, trainable=cfg["trainable_init"])
  outputs, _ = tf.nn.dynamic_rnn(rnn, embedded_tokens, initial_state=state)

  if cfg["loss_compute"] == "last":
    rnn_output = outputs[:, -1]  # grab the last output
  elif cfg["loss_compute"] == "avg":
    rnn_output = tf.reduce_mean(outputs, 1)  # average over length
  elif cfg["loss_compute"] == "max":
    rnn_output = tf.reduce_max(outputs, 1)
  else:
    raise ValueError("Unsupported loss_compute [%s]" % cfg["loss_compute"])

  logits = snt.Linear(batch["label_onehot"].shape[1],
                      initializers=init)(rnn_output)
  loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=batch["label_onehot"], logits=logits)
  return tf.reduce_mean(loss_vec)
def testExistingVocab(self):
  # Check that the module can be initialised with an existing vocabulary.
  existing = np.array([[1, 0, 0, 0],
                       [1, 0, 1, 0],
                       [0, 1, 0, 1]], dtype=np.int32)
  expected = np.array([[1, 0, 0, 0],
                       [0, 1, 0, 1],
                       [1, 0, 1, 0]], dtype=np.int32)
  true_vocab_size, true_embed_dim = existing.shape

  inputs = tf.constant(np.array([0, 2, 1]), dtype=tf.int32)
  embed_mod = snt.Embed(existing_vocab=existing)
  embeddings = embed_mod(inputs)

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    embeddings_ = sess.run(embeddings)

  self.assertAllClose(embeddings_, expected)
  self.assertEqual(embed_mod.vocab_size, true_vocab_size)
  self.assertEqual(embed_mod.embed_dim, true_embed_dim)
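# As the assertions above show, when existing_vocab is supplied snt.Embed
# infers vocab_size and embed_dim from the matrix shape, and a lookup returns
# rows of that matrix: ids [0, 2, 1] select rows 0, 2 and 1 of `existing`,
# which is exactly the `expected` array.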
def build_modules(is_training, vocab_size):
  """Construct the modules used in the graph."""
  # Construct the custom getter which implements Bayes by Backprop.
  if is_training:
    estimator_mode = tf.constant(bbb.EstimatorModes.sample)
  else:
    estimator_mode = tf.constant(bbb.EstimatorModes.mean)
  lstm_bbb_custom_getter = bbb.bayes_by_backprop_getter(
      posterior_builder=lstm_posterior_builder,
      prior_builder=custom_scale_mixture_prior_builder,
      kl_builder=bbb.stochastic_kl_builder,
      sampling_mode_tensor=estimator_mode)
  non_lstm_bbb_custom_getter = bbb.bayes_by_backprop_getter(
      posterior_builder=non_lstm_posterior_builder,
      prior_builder=custom_scale_mixture_prior_builder,
      kl_builder=bbb.stochastic_kl_builder,
      sampling_mode_tensor=estimator_mode)

  embed_layer = snt.Embed(
      vocab_size=vocab_size,
      embed_dim=FLAGS.embedding_size,
      custom_getter=non_lstm_bbb_custom_getter,
      name="input_embedding")

  cores = [
      snt.LSTM(FLAGS.hidden_size,
               custom_getter=lstm_bbb_custom_getter,
               forget_bias=0.0,
               name="lstm_layer_{}".format(i))
      for i in six.moves.range(FLAGS.n_layers)
  ]
  rnn_core = snt.DeepRNN(cores, skip_connections=False, name="deep_lstm_core")

  # Do BBB on weights but not biases of output layer.
  output_linear = snt.Linear(
      vocab_size, custom_getter={"w": non_lstm_bbb_custom_getter})
  return embed_layer, rnn_core, output_linear
def testPartitioners(self):
  # Partition embeddings such that there's one variable per vocabulary entry:
  # each row holds embed_dim float32s, i.e. 4 * embed_dim bytes, which is
  # exactly the max shard size given to the partitioner.
  partitioners = {
      "embeddings": tf.variable_axis_size_partitioner(4 * self._embed_dim)
  }
  embed_mod = snt.Embed(vocab_size=self._vocab_size,
                        embed_dim=self._embed_dim,
                        partitioners=partitioners)
  embeddings = embed_mod(tf.convert_to_tensor(self._ids))
  self.assertEqual(type(embed_mod.embeddings), variables.PartitionedVariable)
  self.assertEqual(len(embed_mod.embeddings), self._vocab_size)

  # Ensure that tf.nn.embedding_lookup() plays nicely with embedding
  # variables.
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(embeddings)
def testComputation(self):
  # Initialize each embedding to its index. Thus, the lookup ids are the same
  # as the embeddings themselves.
  initializers = {
      "embeddings": tf.constant_initializer(
          [[0], [1], [2], [3], [4], [5], [6]], dtype=tf.float32)
  }
  embed_mod = snt.Embed(vocab_size=self._vocab_size,
                        embed_dim=self._embed_dim,
                        initializers=initializers)
  embeddings = embed_mod(tf.convert_to_tensor(self._ids))

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    embeddings_ = sess.run(embeddings)

  expected_embeddings = np.reshape(
      self._ids, newshape=list(self._ids.shape) + [self._embed_dim])
  self.assertAllClose(embeddings_, expected_embeddings)
def _build(self, input_graph):
  embed = snt.Embed(self.num_nodes, self.emb_size)
  linear = snt.Linear(self.hidden)
  norm = snt.BatchNorm()
  mlp = snt.nets.MLP([self.hidden, self.hidden], activate_final=False)
  gnn_attentions = []
  for _ in range(self.num_iters):
    gnn_attention = GraphAttention(self.hidden, self.num_heads)
    gnn_attentions.append(gnn_attention)
  final = snt.Linear(self.num_classes)

  node_features = input_graph.nodes
  if node_features[0] is not None:
    embs = tf.constant(node_features, dtype=tf.float32)
    embs = tf.nn.relu(linear(embs))
    embs = norm(mlp(embs) + embs, is_training=True)
  else:
    embs = embed(self.ids)
  for i in range(self.num_iters):
    embs = gnn_attentions[i](embs, input_graph)
  logits = final(embs)
  self.outputs = logits
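# Design note: this module takes one of two input paths. When the graph ships
# precomputed node features, they are linearly projected, passed through an
# MLP with a residual connection, and batch normalized; otherwise a learned
# snt.Embed table over node ids provides the initial node states fed into the
# graph-attention iterations.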