Beispiel #1
0
            def categorical_loss(labels, logits):
                """Return the mean categorical cross-entropy of ``logits`` against ``labels``.

                ``labels`` arrives as a batch of class ids, one row per example.
                It is transposed and flattened so that the ids are ordered by
                time step, e.g. ``[[1, 2], [3, 4]] -> [1, 3, 2, 4]``, which is the
                order the logits are expected in. The flattened ids are expanded
                to dense one-hot rows with ``vocab_size`` columns (``vocab_size``
                comes from the enclosing scope).
                """
                # time-major flattening: [[1,2],[3,4]] -> [1,3,2,4]
                flat_ids = tx.Reshape(tx.Transpose(labels), [-1])
                targets = tx.dense_one_hot(flat_ids, num_cols=vocab_size)
                per_example = tx.categorical_cross_entropy(labels=targets,
                                                           logits=logits)
                return tf.reduce_mean(per_example)
Beispiel #2
0
    def __init__(
        self,
        run_inputs,
        label_inputs,
        eval_label_input,
        ctx_size,
        k_dim,
        ri_tensor_input,
        embed_dim,
        h_dim,
        embed_init=tx.random_uniform(minval=-0.01, maxval=0.01),
        num_h=1,
        h_activation=tx.relu,
        h_init=tx.he_normal_init,
        use_dropout=False,
        embed_dropout=False,
        keep_prob=0.95,
        l2_loss=False,
        l2_loss_coef=1e-5,
        f_init=tx.random_uniform(minval=-0.01, maxval=0.01),
        use_nce=False,
        nce_samples=2,
        nce_noise_amount=0.1,
        noise_input=None,
    ):
        """Build the run, train and eval graphs and hand them to the base model.

        The run graph looks up embeddings for ``run_inputs``, passes their
        concatenation through ``num_h`` fully connected layers, projects the
        result to an ``embed_dim`` feature vector and scores it against a
        cached, non-trainable copy of the embeddings of ``ri_tensor_input``.
        The train graph re-wires the same layers (optionally with dropout) and
        defines the training loss; the eval graph scores the run logits.

        Args:
            run_inputs: input layer fed to the embedding lookup; also used as
                the train and eval input of the model.
            label_inputs: training targets. Without NCE it is indexed:
                ``label_inputs[0].tensor`` supplies the one-hot column indices
                and ``label_inputs[1].tensor`` the number of columns. With NCE
                ``label_inputs.tensor`` is passed as the label features.
            eval_label_input: evaluation targets; ``eval_label_input[0].tensor``
                supplies the one-hot column indices.
            ctx_size: sequence size of the embedding lookup.
            k_dim: number of rows of the embedding table.
            ri_tensor_input: layer multiplied by the (shared) embedding weights
                to produce the cached embedding matrix; registered as the
                model's ``update_inputs``.
            embed_dim: number of columns of the embedding table and size of
                the predicted feature vector.
            h_dim: number of units in each hidden layer.
            embed_init: initializer for the embedding table.
            num_h: number of hidden layers.
            h_activation: activation of the hidden layers.
            h_init: weight initializer of the hidden layers.
            use_dropout: apply ``tx.Dropout`` after every hidden layer in the
                train graph.
            embed_dropout: together with ``use_dropout``, also apply dropout to
                the concatenated embeddings.
            keep_prob: value passed as ``probability`` to ``tx.Dropout``.
            l2_loss: add an L2 penalty over the registered weights to the
                training loss.
            l2_loss_coef: multiplier of the L2 penalty.
            f_init: initializer for the feature-prediction layer.
            use_nce: train with ``tx.sparse_cnce_loss`` instead of the full
                softmax cross-entropy.
            nce_samples: ``num_samples`` passed to ``tx.sparse_cnce_loss``.
            nce_noise_amount: ``noise_ratio`` passed to
                ``tx.sparse_cnce_loss``.
            noise_input: layer providing the noise features; required when
                ``use_nce`` is True.

        Raises:
            ValueError: if ``use_nce`` is True and ``noise_input`` is None.
        """
        # Fail early with a clear message instead of an AttributeError on
        # ``None.tensor`` halfway through graph construction.
        if use_nce and noise_input is None:
            raise ValueError(
                "noise_input must be provided when use_nce is True")

        self.embed_dim = embed_dim

        # weights that take part in the optional L2 penalty
        var_reg = []

        # ===============================================
        # RUN GRAPH
        # ===============================================

        with tf.name_scope("run"):

            feature_lookup = tx.Lookup(run_inputs,
                                       seq_size=ctx_size,
                                       lookup_shape=[k_dim, embed_dim],
                                       weight_init=embed_init,
                                       name="lookup")

            # keep a handle on the lookup layer before it is replaced by its
            # concatenated view
            self.embeddings = feature_lookup
            var_reg.append(feature_lookup.weights)
            feature_lookup = feature_lookup.as_concat()
            # ===========================================================
            with tf.name_scope("cache_embeddings"):
                # project ri_tensor_input through the embedding table
                all_embeddings = tx.Linear(
                    ri_tensor_input,
                    n_units=self.embed_dim,
                    shared_weights=self.embeddings.weights,
                    bias=False,
                    name='all_features')

                # caches all embedding computation for run/eval
                self.all_embeddings = tx.VariableLayer(all_embeddings,
                                                       trainable=False)
            # ===========================================================
            last_layer = feature_lookup
            h_layers = []
            for i in range(num_h):
                hi = tx.FC(last_layer,
                           n_units=h_dim,
                           activation=h_activation,
                           weight_init=h_init,
                           name="h_{i}".format(i=i))
                h_layers.append(hi)
                last_layer = hi
                var_reg.append(hi.linear.weights)

            self.h_layers = h_layers

            # feature prediction for Energy-Based Model

            f_prediction = tx.Linear(last_layer,
                                     embed_dim,
                                     f_init,
                                     bias=True,
                                     name="f_predict")
            var_reg.append(f_prediction.weights)

            # RI DECODING ===============================================
            # shape is (?,?) because batch size is unknown and vocab size is unknown
            # when we build the graph
            run_logits = tx.Linear(f_prediction,
                                   n_units=None,
                                   shared_weights=self.all_embeddings.variable,
                                   transpose_weights=True,
                                   bias=False,
                                   name="logits")

            # ===========================================================
            embed_prob = tx.Activation(run_logits,
                                       tx.softmax,
                                       name="run_output")

        # ===============================================
        # TRAIN GRAPH
        # ===============================================
        with tf.name_scope("train"):
            if use_dropout and embed_dropout:
                feature_lookup = feature_lookup.reuse_with(run_inputs)
                last_layer = tx.Dropout(feature_lookup, probability=keep_prob)
            else:
                last_layer = feature_lookup

            # add dropout between each layer
            for layer in h_layers:
                h = layer.reuse_with(last_layer)
                if use_dropout:
                    h = tx.Dropout(h, probability=keep_prob)
                last_layer = h

            f_prediction = f_prediction.reuse_with(last_layer)

            train_logits = run_logits.reuse_with(f_prediction,
                                                 name="train_logits")
            train_embed_prob = tx.Activation(train_logits,
                                             tx.softmax,
                                             name="train_output")

            model_prediction = f_prediction.tensor

            if use_nce:
                train_loss = tx.sparse_cnce_loss(
                    label_features=label_inputs.tensor,
                    noise_features=noise_input.tensor,
                    model_prediction=model_prediction,
                    weights=feature_lookup.weights,
                    num_samples=nce_samples,
                    noise_ratio=nce_noise_amount)
            else:
                one_hot_dense = tx.dense_one_hot(
                    column_indices=label_inputs[0].tensor,
                    num_cols=label_inputs[1].tensor)
                train_loss = tx.categorical_cross_entropy(
                    one_hot_dense, train_logits.tensor)

                train_loss = tf.reduce_mean(train_loss)

            if l2_loss:
                losses = [tf.nn.l2_loss(var) for var in var_reg]
                train_loss = train_loss + l2_loss_coef * tf.add_n(losses)

        # ===============================================
        # EVAL GRAPH
        # ===============================================
        with tf.name_scope("eval"):
            # NOTE(review): the number of columns is still read from
            # label_inputs[1] rather than from eval_label_input - confirm
            # that both share that layer.
            one_hot_dense = tx.dense_one_hot(
                column_indices=eval_label_input[0].tensor,
                num_cols=label_inputs[1].tensor)
            # train_loss must NOT be reassigned here: doing so replaced the
            # NCE / L2-regularised training loss with an unreduced
            # cross-entropy over the evaluation labels.
            eval_loss = tx.categorical_cross_entropy(one_hot_dense,
                                                     run_logits.tensor)
            eval_loss = tf.reduce_mean(eval_loss)

        if use_nce:
            train_loss_in = [label_inputs, noise_input]
        else:
            train_loss_in = label_inputs

        # BUILD MODEL
        super().__init__(run_inputs=run_inputs,
                         run_outputs=embed_prob,
                         train_inputs=run_inputs,
                         train_outputs=train_embed_prob,
                         eval_inputs=run_inputs,
                         eval_outputs=embed_prob,
                         train_out_loss=train_loss,
                         train_in_loss=train_loss_in,
                         eval_out_score=eval_loss,
                         eval_in_score=eval_label_input,
                         update_inputs=ri_tensor_input)
Beispiel #3
0
    def __init__(self,
                 ctx_size,
                 vocab_size,
                 embed_dim,
                 embed_init=tx.random_uniform(minval=-0.01, maxval=0.01),
                 x_to_f_init=tx.random_uniform(minval=-0.01, maxval=0.01),
                 logit_init=tx.random_uniform(minval=-0.01, maxval=0.01),
                 embed_share=True,
                 use_gate=True,
                 use_hidden=False,
                 h_dim=100,
                 h_activation=tx.elu,
                 h_init=tx.he_normal_init(),
                 h_to_f_init=tx.random_uniform(minval=-0.01, maxval=0.01),
                 use_dropout=True,
                 embed_dropout=False,
                 keep_prob=0.95,
                 l2_loss=False,
                 l2_loss_coef=1e-5,
                 use_nce=False,
                 nce_samples=100):
        """Build the run, train and eval graphs and register them with the base model.

        The run graph looks up an embedding for each of the ``ctx_size`` input
        ids, optionally feeds the concatenated embeddings through a hidden
        layer and a gate, projects the result to an ``embed_dim`` feature
        vector and computes a softmax over ``vocab_size`` logits. The train
        graph re-wires the same layers with optional dropout and defines the
        loss; the eval graph scores the run logits against the target.

        Args:
            ctx_size: number of ids per example (input width and lookup
                sequence size); also the number of gate units.
            vocab_size: number of rows of the embedding table and number of
                output logits.
            embed_dim: number of columns of the embedding table and size of
                the predicted feature vector.
            embed_init: initializer for the embedding table.
            x_to_f_init: initializer for the features-to-prediction layer.
            logit_init: initializer for the logits layer; replaced by None
                when ``embed_share`` is True.
            embed_share: if True, the logits layer uses the transposed
                embedding table as its weights.
            use_gate: gate the concatenated embeddings with a gate driven by
                the hidden layer.
            use_hidden: add a hidden-to-prediction projection that is summed
                with the features-to-prediction projection.
            h_dim: number of hidden units.
            h_activation: activation of the hidden layer.
            h_init: weight initializer of the hidden layer.
            h_to_f_init: initializer for the hidden-to-prediction layer.
            use_dropout: apply ``tx.Dropout`` in the train graph.
            embed_dropout: together with ``use_dropout``, also apply dropout
                to the concatenated embeddings.
            keep_prob: value passed as ``probability`` to ``tx.Dropout``.
            l2_loss: add an L2 penalty over the registered weights to the
                training loss.
            l2_loss_coef: multiplier of the L2 penalty.
            use_nce: train with ``tf.nn.nce_loss`` instead of the full softmax
                cross-entropy.
            nce_samples: number of sampled classes for NCE.
        """

        # GRAPH INPUTS
        # one int32 id per context position, plus a single int32 target id
        run_inputs = tx.Input(ctx_size, dtype=tf.int32, name="input")
        loss_inputs = tx.Input(n_units=1, dtype=tf.int32, name="target")
        # evaluation is fed through the same target input layer as training
        eval_inputs = loss_inputs

        # RUN GRAPH
        # Wrapping everything in one outer scope makes the TensorBoard graph
        # hard to read because it groups nodes by nested scope names.
        # Using separate scopes for run, train and eval keeps the graph
        # readable: the same layer names can appear under different scopes
        # while the variables are still shared.
        # weights that take part in the optional L2 penalty
        var_reg = []
        with tf.name_scope("run"):
            feature_lookup = tx.Lookup(run_inputs,
                                       ctx_size, [vocab_size, embed_dim],
                                       embed_init,
                                       name="lookup")
            var_reg.append(feature_lookup.weights)
            # from here on feature_lookup is the concatenated view of the lookup
            feature_lookup = feature_lookup.as_concat()

            # the hidden layer is needed both as gate input and as the source
            # of the hidden-to-prediction projection
            if use_gate or use_hidden:
                hl = tx.Linear(feature_lookup,
                               h_dim,
                               h_init,
                               bias=True,
                               name="h_linear")
                ha = tx.Activation(hl, h_activation, name="h_activation")
                h = tx.Compose(hl, ha, name="hidden")
                var_reg.append(hl.weights)

            features = feature_lookup
            if use_gate:
                # one gate unit per context position, driven by the hidden layer
                gate_w = tx.Linear(h, ctx_size, bias=True)
                gate = tx.Gate(features, gate_input=gate_w)

                # gate = tx.Module([h, features], gate)

                features = gate
                var_reg.append(gate_w.weights)

            x_to_f = tx.Linear(features,
                               embed_dim,
                               x_to_f_init,
                               bias=True,
                               name="x_to_f")
            var_reg.append(x_to_f.weights)
            f_prediction = x_to_f

            if use_hidden:
                h_to_f = tx.Linear(h,
                                   embed_dim,
                                   h_to_f_init,
                                   bias=True,
                                   name="h_to_f")
                var_reg.append(h_to_f.weights)
                # predicted features = direct projection + hidden projection
                f_prediction = tx.Add(x_to_f, h_to_f, name="f_predicted")

            # RI DECODING ===============================================
            # with embed_share the output layer reuses the transposed embedding
            # table, so no separate initializer is needed
            shared_weights = tf.transpose(
                feature_lookup.weights) if embed_share else None
            logit_init = logit_init if not embed_share else None
            run_logits = tx.Linear(f_prediction,
                                   vocab_size,
                                   logit_init,
                                   shared_weights,
                                   bias=True,
                                   name="logits")
            # shared weights are already registered through the lookup
            if not embed_share:
                var_reg.append(run_logits.weights)
            y_prob = tx.Activation(run_logits, tx.softmax)

        # TRAIN GRAPH ===============================================
        with tf.name_scope("train"):
            if use_dropout and embed_dropout:
                feature_lookup = feature_lookup.reuse_with(run_inputs)
                features = tx.Dropout(feature_lookup, probability=keep_prob)
            else:
                features = feature_lookup

            if use_gate or use_hidden:
                # without dropout the run-graph hidden layer is used as is
                if use_dropout:
                    h = h.reuse_with(features)
                    h = tx.Dropout(h, probability=keep_prob)

                if use_gate:
                    gate_w = gate_w.reuse_with(h)
                    features = gate.reuse_with(layer=features,
                                               gate_input=gate_w)

                f_prediction = x_to_f.reuse_with(features)

                if use_hidden:
                    h_to_f = h_to_f.reuse_with(h)
                    if use_dropout:
                        h_to_f = tx.Dropout(h_to_f, probability=keep_prob)
                    f_prediction = tx.Add(f_prediction, h_to_f)
            else:
                # no gate and no hidden layer: f_prediction is x_to_f itself
                f_prediction = f_prediction.reuse_with(features)

            train_logits = run_logits.reuse_with(f_prediction)

            if use_nce:
                # uniform gets good enough results if enough samples are used
                # but we can load the empirical unigram distribution
                # or learn the unigram distribution during training
                sampled_values = uniform_sampler(loss_inputs.tensor, 1,
                                                 nce_samples, True, vocab_size)
                # the logits weights are transposed before being handed to
                # nce_loss
                train_loss = tf.nn.nce_loss(weights=tf.transpose(
                    train_logits.weights),
                                            biases=train_logits.bias,
                                            inputs=f_prediction.tensor,
                                            labels=loss_inputs.tensor,
                                            num_sampled=nce_samples,
                                            num_classes=vocab_size,
                                            num_true=1,
                                            sampled_values=sampled_values)
            else:
                one_hot = tx.dense_one_hot(column_indices=loss_inputs.tensor,
                                           num_cols=vocab_size)
                train_loss = tx.categorical_cross_entropy(
                    one_hot, train_logits.tensor)

            # average the per-example loss over the batch
            train_loss = tf.reduce_mean(train_loss)

            if l2_loss:
                losses = [tf.nn.l2_loss(var) for var in var_reg]
                train_loss = train_loss + l2_loss_coef * tf.add_n(losses)

        # EVAL GRAPH ===============================================
        # evaluation always uses the full cross-entropy over the run logits
        with tf.name_scope("eval"):
            one_hot = tx.dense_one_hot(column_indices=eval_inputs.tensor,
                                       num_cols=vocab_size)
            eval_loss = tx.categorical_cross_entropy(one_hot,
                                                     run_logits.tensor)
            eval_loss = tf.reduce_mean(eval_loss)

        # SETUP MODEL CONTAINER ====================================
        # NOTE(review): train_outputs is the run-graph softmax (y_prob), not a
        # softmax over train_logits - confirm this is intended.
        super().__init__(run_inputs=run_inputs,
                         run_outputs=y_prob,
                         train_inputs=run_inputs,
                         train_outputs=y_prob,
                         eval_inputs=run_inputs,
                         eval_outputs=y_prob,
                         train_out_loss=train_loss,
                         train_in_loss=loss_inputs,
                         eval_out_score=eval_loss,
                         eval_in_score=eval_inputs)
Beispiel #4
0
                           name="all_features",
                           bias=False)

# Logits: dot product of the predicted features with the embeddings of all
# words. The embedding matrix is used transposed as the layer weights and no
# bias is added.
run_logits = tx.Linear(feature_predict,
                       vocab_size,
                       shared_weights=all_embeddings.tensor,
                       transpose_weights=True,
                       bias=False,
                       name="logits")

# softmax over the logits
embed_prob = tx.Activation(run_logits, tx.softmax, name="run_output")

# validation loss: mean cross-entropy of the logits against the one-hot labels
one_hot = tx.dense_one_hot(column_indices=input_labels.tensor,
                           num_cols=vocab_size)
val_loss = tx.categorical_cross_entropy(one_hot, run_logits.tensor)
val_loss = tf.reduce_mean(val_loss)

# *************************************
#   Testing adaptive noise
# *************************************
# TODO: test the infinite-vocabulary scenario, where we try to generate the
# RIs directly; sparsemax could be used in that case.
# The noise distribution is a softmax over a linear projection of the lookup.
noise_logits = tx.Linear(lookup, vocab_size, bias=True)
adaptive_noise = tx.Activation(noise_logits, tx.softmax)

# Alternative kept for reference: sample sparse noise from the logits.
# adaptive_noise = tx.sample_sigmoid_from_logits(noise_logits.tensor, n=1)
# adaptive_noise = tx.TensorLayer(adaptive_noise, n_units=k)
# adaptive_noise = tx.to_sparse(adaptive_noise)

# *************************************
Beispiel #5
0
# Recurrent cell: a single LSTM cell with n_hidden units.
cell = tf.nn.rnn_cell.LSTMCell(num_units=n_hidden, state_is_tuple=True)
# Unroll the cell over the input sequence; `state` is not used below.
# assumes lookup_to_seq is batch-major ([batch, time, features]), which is
# dynamic_rnn's default layout - TODO confirm where it is built
val, state = tf.nn.dynamic_rnn(cell, lookup_to_seq, dtype=tf.float32)

# swap the first two axes so that indexing the first axis selects a time step
val = tf.transpose(val, [1, 0, 2])

# take the output of the last step
# last = tf.gather(val, int(val.get_shape()[0]) - 1)
last = val[-1]

# wrap the raw tensor as a layer and stack a softmax classifier on top
lstm_out = tx.TensorLayer(last, n_hidden)
logits = tx.Linear(lstm_out, vocab_size, bias=True)
out = tx.Activation(logits, tx.softmax)

# mean cross-entropy of the logits against the one-hot targets
labels = tx.dense_one_hot(loss_inputs.tensor, vocab_size)
loss = tf.reduce_mean(tx.categorical_cross_entropy(labels=labels, logits=logits.tensor))

# setup optimizer
optimizer = tx.AMSGrad(learning_rate=0.01)

# the same loss tensor is used for training and as the evaluation score
model = tx.Model(run_inputs=in_layer, run_outputs=out,
                 train_inputs=in_layer, train_outputs=out,
                 train_in_loss=loss_inputs, train_out_loss=loss,
                 eval_out_score=loss, eval_in_score=loss_inputs)

print(model.feedable_train())

# attach the optimizer to a runner and initialise the variables
runner = tx.ModelRunner(model)
runner.config_optimizer(optimizer)

runner.init_vars()
Beispiel #6
0
 def categorical_loss(labels, logits):
     """Return the mean categorical cross-entropy of ``logits`` against ``labels``.

     ``labels`` holds class ids; they are expanded to dense one-hot rows with
     ``vocab_size`` columns (``vocab_size`` comes from the enclosing scope)
     before the cross-entropy is computed and averaged.
     """
     targets = tx.dense_one_hot(column_indices=labels, num_cols=vocab_size)
     # same role as tf.nn.softmax_cross_entropy_with_logits_v2, via tx
     return tf.reduce_mean(
         tx.categorical_cross_entropy(labels=targets, logits=logits))
Beispiel #7
0
# Inputs: a single int32 target id and seq_size int32 context ids per example.
loss_inputs = tx.Input(1, dtype=tf.int32)
in_layer = tx.Input(seq_size, dtype=tf.int32)

# embedding lookup for every id in the sequence
lookup = tx.Lookup(in_layer, seq_size=seq_size, lookup_shape=feature_shape)
# [batch x seq_size * feature_shape[1]]

# hidden layer: linear projection followed by ELU, composed into one layer
h = tx.Linear(lookup, n_hidden, bias=True)
ha = tx.Activation(h, tx.elu)
h = tx.Compose(h, ha)

# softmax classifier over the vocabulary
logits = tx.Linear(h, vocab_size, bias=True)
out = tx.Activation(logits, tx.softmax)

# mean cross-entropy of the logits against the one-hot targets
labels = tx.dense_one_hot(loss_inputs.tensor, vocab_size)
loss = tf.reduce_mean(
    tx.categorical_cross_entropy(labels=labels, logits=logits.tensor))

# setup optimizer
optimizer = tx.AMSGrad(learning_rate=0.01)

# the same loss tensor is used for training and as the evaluation score
model = tx.Model(run_inputs=in_layer,
                 run_outputs=out,
                 train_inputs=in_layer,
                 train_outputs=out,
                 train_in_loss=loss_inputs,
                 train_out_loss=loss,
                 eval_out_score=loss,
                 eval_in_score=loss_inputs)

print(model.feedable_train())
Beispiel #8
0
    def __init__(self,
                 ctx_size,
                 vocab_size,
                 k_dim,
                 ri_tensor: RandomIndexTensor,
                 embed_dim,
                 embed_init=tx.random_uniform(minval=-0.01, maxval=0.01),
                 x_to_f_init=tx.random_uniform(minval=-0.01, maxval=0.01),
                 logit_init=tx.random_uniform(minval=-0.01, maxval=0.01),
                 embed_share=True,
                 logit_bias=False,
                 use_gate=True,
                 use_hidden=False,
                 h_dim=100,
                 h_activation=tx.elu,
                 h_init=tx.he_normal_init(),
                 h_to_f_init=tx.random_uniform(minval=-0.01, maxval=0.01),
                 use_dropout=True,
                 embed_dropout=False,
                 keep_prob=0.95,
                 l2_loss=False,
                 l2_loss_coef=1e-5):
        """Build the run, train and eval graphs and register them with the base model.

        Input ids are first mapped to their random-index (RI) rows taken from
        ``ri_tensor``; those rows are used to look up ``embed_dim`` features
        from a ``[k_dim, embed_dim]`` table. The concatenated features
        optionally go through a hidden layer and a gate and are projected to
        an ``embed_dim`` prediction, which is scored against the embeddings of
        all RI rows to produce ``vocab_size`` logits.

        Args:
            ctx_size: number of ids per example (input width and lookup
                sequence size); also passed to ``tx.Gate``.
            vocab_size: number of output logits.
            k_dim: number of rows of the feature table; also passed as the
                unit count of the RI layers.
            ri_tensor: either a ``RandomIndexTensor`` (converted with
                ``to_sparse_tensor`` / ``gather``) or any other value, which
                is handed to ``tx.TensorLayer`` and ``tx.gather_sparse``.
            embed_dim: number of columns of the feature table and size of the
                predicted feature vector.
            embed_init: initializer for the feature table.
            x_to_f_init: initializer for the features-to-prediction layer.
            logit_init: initializer for the layer that embeds all RI rows;
                replaced by None when ``embed_share`` is True.
            embed_share: if True, that layer reuses the lookup weights.
            logit_bias: ``bias`` argument of the logits layer.
            use_gate: gate the concatenated features using the hidden layer.
            use_hidden: add a hidden-to-prediction projection that is summed
                with the features-to-prediction projection.
            h_dim: number of hidden units.
            h_activation: activation of the hidden layer.
            h_init: weight initializer of the hidden layer.
            h_to_f_init: initializer for the hidden-to-prediction layer.
            use_dropout: apply ``tx.Dropout`` in the train graph.
            embed_dropout: together with ``use_dropout``, also apply dropout
                to the concatenated features.
            keep_prob: value passed as ``probability`` to ``tx.Dropout``.
            l2_loss: add an L2 penalty over the registered weights to the
                training loss.
            l2_loss_coef: multiplier of the L2 penalty.
        """

        # GRAPH INPUTS
        # one int32 id per context position, plus a single int32 target id
        run_inputs = tx.Input(ctx_size, dtype=tf.int32, name="input")
        loss_inputs = tx.Input(n_units=1, dtype=tf.int32, name="target")
        # evaluation is fed through the same target input layer as training
        eval_inputs = loss_inputs

        # RUN GRAPH =====================================================
        # weights that take part in the optional L2 penalty
        var_reg = []
        with tf.name_scope("run"):
            # RI ENCODING ===============================================
            # convert ids to RIs: gather a set of random indexes based on the
            # ids in a sequence

            # ri_layer = tx.TensorLayer(ri_tensor, n_units=k_dim)
            # ri_inputs = tx.gather_sparse(ri_layer.tensor, run_inputs.tensor)
            with tf.name_scope("ri_encode"):
                # ri_layer holds all RI rows and is used to compute the logits;
                # ri_inputs holds only the rows selected by the input ids
                if isinstance(ri_tensor, RandomIndexTensor):
                    ri_layer = tx.TensorLayer(ri_tensor.to_sparse_tensor(),
                                              k_dim)

                    ri_inputs = ri_tensor.gather(run_inputs.tensor)
                    ri_inputs = ri_inputs.to_sparse_tensor()
                    ri_inputs = tx.TensorLayer(ri_inputs, k_dim)
                else:
                    # presumably ri_tensor is already a sparse tensor here -
                    # verify against callers
                    ri_layer = tx.TensorLayer(ri_tensor, k_dim)
                    ri_inputs = tx.gather_sparse(ri_layer.tensor,
                                                 run_inputs.tensor)
                    ri_inputs = tx.TensorLayer(ri_inputs, k_dim)

            # use those sparse indexes to lookup a set of features based on the ri values
            feature_lookup = tx.Lookup(ri_inputs,
                                       ctx_size, [k_dim, embed_dim],
                                       embed_init,
                                       name="lookup")
            var_reg.append(feature_lookup.weights)
            # from here on feature_lookup is the concatenated view of the lookup
            feature_lookup = feature_lookup.as_concat()
            # ===========================================================

            # the hidden layer is needed both as gate input and as the source
            # of the hidden-to-prediction projection
            if use_gate or use_hidden:
                hl = tx.Linear(feature_lookup,
                               h_dim,
                               h_init,
                               bias=True,
                               name="h_linear")
                ha = tx.Activation(hl, h_activation, name="h_activation")
                h = tx.Compose(hl, ha, name="hidden")
                var_reg.append(hl.weights)

            features = feature_lookup
            if use_gate:
                features = tx.Gate(features, ctx_size, gate_input=h)
                # keep a handle on the gate so the train graph can re-wire it
                gate = features
                var_reg.append(features.gate_weights)

            x_to_f = tx.Linear(features,
                               embed_dim,
                               x_to_f_init,
                               bias=True,
                               name="x_to_f")
            var_reg.append(x_to_f.weights)
            f_prediction = x_to_f

            if use_hidden:
                h_to_f = tx.Linear(h,
                                   embed_dim,
                                   h_to_f_init,
                                   bias=True,
                                   name="h_to_f")
                var_reg.append(h_to_f.weights)
                # predicted features = direct projection + hidden projection
                f_prediction = tx.Add(x_to_f, h_to_f, name="f_predicted")

            # RI DECODING ===============================================
            # with embed_share the RI rows are embedded with the lookup
            # weights, so no separate initializer is needed
            shared_weights = feature_lookup.weights if embed_share else None
            logit_init = logit_init if not embed_share else None
            # embedding feature vectors for all words: shape [vocab_size, embed_dim]
            # later, for NCE we don't need to get all the features

            # NOTE: this layer is named "logits" although it produces the
            # embeddings of all RI rows; the logits themselves are run_logits
            all_embeddings = tx.Linear(ri_layer,
                                       embed_dim,
                                       logit_init,
                                       shared_weights,
                                       name="logits",
                                       bias=False)

            # dot product of f_predicted . all_embeddings, with a bias for
            # each target word only when logit_bias is set

            run_logits = tx.Linear(f_prediction,
                                   n_units=vocab_size,
                                   shared_weights=all_embeddings.tensor,
                                   transpose_weights=True,
                                   bias=logit_bias)

            # shared weights are already registered through the lookup
            if not embed_share:
                var_reg.append(all_embeddings.weights)

            # ===========================================================
            run_embed_prob = tx.Activation(run_logits, tx.softmax)

        # TRAIN GRAPH ===================================================
        with tf.name_scope("train"):
            if use_dropout and embed_dropout:
                feature_lookup = feature_lookup.reuse_with(ri_inputs)
                features = tx.Dropout(feature_lookup, probability=keep_prob)
            else:
                features = feature_lookup

            if use_gate or use_hidden:
                # without dropout the run-graph hidden layer is used as is
                if use_dropout:
                    h = h.reuse_with(features)
                    h = tx.Dropout(h, probability=keep_prob)

                if use_gate:
                    features = gate.reuse_with(features, gate_input=h)

                f_prediction = x_to_f.reuse_with(features)

                if use_hidden:
                    h_to_f = h_to_f.reuse_with(h)
                    if use_dropout:
                        h_to_f = tx.Dropout(h_to_f, probability=keep_prob)
                    f_prediction = tx.Add(f_prediction, h_to_f)
            else:
                # no gate and no hidden layer: f_prediction is x_to_f itself
                f_prediction = f_prediction.reuse_with(features)

            # all_embeddings, from which these logits are computed, was
            # already defined in the run graph, so reusing the layer should be ok
            train_logits = run_logits.reuse_with(f_prediction)

            train_embed_prob = tx.Activation(train_logits,
                                             tx.softmax,
                                             name="train_output")

            one_hot = tx.dense_one_hot(column_indices=loss_inputs.tensor,
                                       num_cols=vocab_size)
            train_loss = tx.categorical_cross_entropy(one_hot,
                                                      train_logits.tensor)

            # average the per-example loss over the batch
            train_loss = tf.reduce_mean(train_loss)

            if l2_loss:
                losses = [tf.nn.l2_loss(var) for var in var_reg]
                train_loss = train_loss + l2_loss_coef * tf.add_n(losses)

        # EVAL GRAPH ===============================================
        # evaluation scores the run logits (no dropout) against the target
        with tf.name_scope("eval"):
            one_hot = tx.dense_one_hot(column_indices=eval_inputs.tensor,
                                       num_cols=vocab_size)
            eval_loss = tx.categorical_cross_entropy(one_hot,
                                                     run_logits.tensor)
            eval_loss = tf.reduce_mean(eval_loss)

        # SETUP MODEL CONTAINER ====================================
        super().__init__(run_inputs=run_inputs,
                         run_outputs=run_embed_prob,
                         train_inputs=run_inputs,
                         train_outputs=train_embed_prob,
                         eval_inputs=run_inputs,
                         eval_outputs=run_embed_prob,
                         train_out_loss=train_loss,
                         train_in_loss=loss_inputs,
                         eval_out_score=eval_loss,
                         eval_in_score=eval_inputs)
Beispiel #9
0
    def __init__(self,
                 ctx_size,
                 vocab_size,
                 k_dim,
                 s_active,
                 ri_tensor,
                 embed_dim,
                 h_dim,
                 embed_init=tx.random_uniform(minval=-0.01, maxval=0.01),
                 logit_init=tx.random_uniform(minval=-0.01, maxval=0.01),
                 num_h=1,
                 h_activation=tx.relu,
                 h_init=tx.he_normal_init,
                 use_dropout=False,
                 embed_dropout=False,
                 keep_prob=0.95,
                 l2_loss=False,
                 l2_loss_coef=1e-5,
                 f_init=tx.random_uniform(minval=-0.01, maxval=0.01),
                 embed_share=True,
                 logit_bias=False,
                 use_nce=False,
                 nce_samples=100,
                 noise_level=0.1):

        run_inputs = tx.Input(ctx_size, dtype=tf.int32)
        loss_inputs = tx.Input(n_units=1, dtype=tf.int64)
        eval_inputs = loss_inputs

        if run_inputs.dtype != tf.int32 and run_inputs.dtype != tf.int64:
            raise TypeError(
                "Invalid dtype for input: expected int32 or int64, got {}".
                format(run_inputs.dtype))

        if num_h < 0:
            raise ValueError("num hidden should be >= 0")

        # ===============================================
        # RUN GRAPH
        # ===============================================
        var_reg = []

        with tf.name_scope("run"):
            # RI ENCODING ===============================================
            # convert ids to ris gather a set of random indexes based on the ids in a sequence
            # ri_layer = tx.TensorLayer(ri_tensor, n_units=k_dim)
            # ri_inputs = tx.gather_sparse(ri_layer.tensor, run_inputs.tensor)
            # ri_inputs = tx.TensorLayer(ri_inputs, n_units=k_dim)
            with tf.name_scope("ri_encode"):
                if isinstance(ri_tensor, RandomIndexTensor):
                    ri_tensor = ri_tensor
                    ri_layer = tx.TensorLayer(ri_tensor.to_sparse_tensor(),
                                              k_dim,
                                              shape=[vocab_size, k_dim])

                    ri_inputs = ri_tensor.gather(run_inputs.tensor)
                    ri_inputs = ri_inputs.to_sparse_tensor()
                    ri_inputs = tx.TensorLayer(
                        ri_inputs,
                        k_dim,
                        shape=[ri_inputs.get_shape()[0], k_dim])
                # ri_tensor is a sparse tensor
                else:
                    raise TypeError(
                        "please supply RandomIndexTensor instead of sparse Tensor"
                    )
                    # ri_layer = tx.TensorLayer(ri_tensor, k_dim)
                    # ri_inputs = tx.gather_sparse(ri_layer.tensor, run_inputs.tensor)
                    # ri_inputs = tx.TensorLayer(ri_inputs, k_dim)

            feature_lookup = tx.Lookup(ri_inputs,
                                       ctx_size, [k_dim, embed_dim],
                                       embed_init,
                                       name="lookup")
            self.embeddings = feature_lookup
            var_reg.append(feature_lookup.weights)
            feature_lookup = feature_lookup.as_concat()
            # ===========================================================

            last_layer = feature_lookup
            h_layers = []
            for i in range(num_h):
                h_i = tx.Linear(last_layer,
                                h_dim,
                                h_init,
                                bias=True,
                                name="h_{i}_linear".format(i=i))
                h_a = tx.Activation(h_i, h_activation)
                h = tx.Compose(h_i, h_a, name="h_{i}".format(i=i))
                h_layers.append(h)
                last_layer = h
                var_reg.append(h_i.weights)

            self.h_layers = h_layers

            # feature prediction for Energy-Based Model

            f_prediction = tx.Linear(last_layer,
                                     embed_dim,
                                     f_init,
                                     bias=True,
                                     name="f_predict")
            var_reg.append(f_prediction.weights)

            # RI DECODING ===============================================

            # Shared Embeddings
            if embed_share:
                shared_weights = feature_lookup.weights if embed_share else None
                logit_init = logit_init if not embed_share else None

                # ri_dense = tx.ToDense(ri_layer)
                all_embeddings = tx.Linear(ri_layer,
                                           embed_dim,
                                           logit_init,
                                           shared_weights,
                                           name="all_features",
                                           bias=False)

                # dot product of f_predicted . all_embeddings with bias for each target word
                run_logits = tx.Linear(f_prediction,
                                       vocab_size,
                                       shared_weights=all_embeddings.tensor,
                                       transpose_weights=True,
                                       bias=logit_bias,
                                       name="logits")
            else:
                run_logits = tx.Linear(f_prediction,
                                       vocab_size,
                                       bias=logit_bias,
                                       name="logits")

            if not embed_share:
                var_reg.append(run_logits.weights)
            # ===========================================================

            embed_prob = tx.Activation(run_logits,
                                       tx.softmax,
                                       name="run_output")

        # ===============================================
        # TRAIN GRAPH
        # ===============================================
        with tf.name_scope("train"):
            if use_dropout and embed_dropout:
                feature_lookup = feature_lookup.reuse_with(ri_inputs)
                last_layer = tx.Dropout(feature_lookup, probability=keep_prob)
            else:
                last_layer = feature_lookup

            # add dropout between each layer
            for layer in h_layers:
                h = layer.reuse_with(last_layer)
                if use_dropout:
                    h = tx.Dropout(h, probability=keep_prob)
                last_layer = h

            f_prediction = f_prediction.reuse_with(last_layer)

            train_logits = run_logits.reuse_with(f_prediction,
                                                 name="train_logits")
            train_embed_prob = tx.Activation(train_logits,
                                             tx.softmax,
                                             name="train_output")

            if use_nce:
                # labels
                labels = loss_inputs.tensor

                #  convert labels to random indices
                def labels_to_ri(x):
                    """Convert labels ``x`` into sparse random-index features.

                    Gathers the entries for ``x`` from the enclosing
                    ``ri_tensor`` and returns whatever its
                    ``to_sparse_tensor()`` call produces.
                    """
                    return ri_tensor.gather(x).to_sparse_tensor()

                model_prediction = f_prediction.tensor

                train_loss = tx.sparse_cnce_loss(
                    label_features=labels,
                    model_prediction=model_prediction,
                    weights=feature_lookup.weights,
                    noise_ratio=noise_level,
                    num_samples=nce_samples,
                    labels_to_sparse_features=labels_to_ri)

            else:
                one_hot = tx.dense_one_hot(column_indices=loss_inputs.tensor,
                                           num_cols=vocab_size)
                train_loss = tx.categorical_cross_entropy(
                    one_hot, train_logits.tensor)

                train_loss = tf.reduce_mean(train_loss)

            if l2_loss:
                losses = [tf.nn.l2_loss(var) for var in var_reg]
                train_loss = train_loss + l2_loss_coef * tf.add_n(losses)

        # ===============================================
        # EVAL GRAPH
        # ===============================================
        with tf.name_scope("eval"):
            one_hot = tx.dense_one_hot(column_indices=eval_inputs.tensor,
                                       num_cols=vocab_size)
            eval_loss = tx.categorical_cross_entropy(one_hot,
                                                     run_logits.tensor)
            eval_loss = tf.reduce_mean(eval_loss)

        # BUILD MODEL
        super().__init__(run_inputs=run_inputs,
                         run_outputs=embed_prob,
                         train_inputs=run_inputs,
                         train_outputs=train_embed_prob,
                         eval_inputs=run_inputs,
                         eval_outputs=embed_prob,
                         train_out_loss=train_loss,
                         train_in_loss=loss_inputs,
                         eval_out_score=eval_loss,
                         eval_in_score=eval_inputs)