Code example #1
File: rmc_att_topic.py  Project: federicoBetti/RelGAN
    def __init__(self,
                 x_real,
                 temperature,
                 x_topic,
                 vocab_size,
                 batch_size,
                 seq_len,
                 gen_emb_dim,
                 mem_slots,
                 head_size,
                 num_heads,
                 hidden_dim,
                 start_token,
                 use_lambda=True,
                 **kwargs):
        self.start_tokens = tf.constant([start_token] * batch_size,
                                        dtype=tf.int32)
        self.output_memory_size = mem_slots * head_size * num_heads
        self.seq_len = seq_len
        self.x_real = x_real
        self.x_topic = x_topic
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.kwargs = kwargs
        self.vocab_size = vocab_size
        self.temperature = temperature
        self.topic_in_memory = kwargs["TopicInMemory"]
        self.no_topic = kwargs["NoTopic"]
        self.gen_emb_dim = gen_emb_dim

        self.g_embeddings = tf.get_variable(
            'g_emb',
            shape=[vocab_size, gen_emb_dim],
            initializer=create_linear_initializer(vocab_size))
        self.gen_mem = RelationalMemory(mem_slots=mem_slots,
                                        head_size=head_size,
                                        num_heads=num_heads)
        self.g_output_unit = create_output_unit(self.output_memory_size,
                                                vocab_size)
        self.g_topic_embedding = create_topic_embedding_unit(
            vocab_size, gen_emb_dim)
        self.g_output_unit_lambda = create_output_unit_lambda(
            output_size=1,
            input_size=self.output_memory_size,
            additive_scope="_lambda",
            min_value=0.01)
        self.first_embedding = self.first_embedding_function

        # initial states
        self.init_states = self.gen_mem.initial_state(batch_size)
        self.create_recurrent_adv()
        self.create_pretrain()
Code example #2
    def __init__(self, x_real, temperature, x_sentiment, vocab_size,
                 batch_size, seq_len, gen_emb_dim, mem_slots, head_size,
                 num_heads, hidden_dim, start_token, sentiment_num, **kwargs):
        self.generated_num = None
        self.x_real = x_real
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.start_tokens = tf.constant([start_token] * batch_size,
                                        dtype=tf.int32)
        output_memory_size = mem_slots * head_size * num_heads
        self.temperature = temperature
        self.x_sentiment = x_sentiment

        self.g_embeddings = tf.get_variable(
            'g_emb',
            shape=[vocab_size, gen_emb_dim],
            initializer=create_linear_initializer(vocab_size))
        self.gen_mem = RelationalMemory(mem_slots=mem_slots,
                                        head_size=head_size,
                                        num_heads=num_heads)
        self.g_output_unit = create_output_unit(output_memory_size, vocab_size)

        # attribute handling
        self.g_sentiment = linear(input_=tf.one_hot(self.x_sentiment,
                                                    sentiment_num),
                                  output_size=gen_emb_dim,
                                  use_bias=True,
                                  scope="linear_x_sentiment")

        # self_attention_unit = create_self_attention_unit(scope="attribute_self_attention") #todo

        # initial states
        self.init_states = self.gen_mem.initial_state(batch_size)

        # ---------- generate tokens and approximated one-hot results (Adversarial) ---------
        self.gen_o = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                  size=seq_len,
                                                  dynamic_size=False,
                                                  infer_shape=True)
        self.gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                                  size=seq_len,
                                                  dynamic_size=False,
                                                  infer_shape=True)
        self.gen_x_onehot_adv = tensor_array_ops.TensorArray(
            dtype=tf.float32,
            size=seq_len,
            dynamic_size=False,
            infer_shape=True)
        self.pretrain_loss = None
        self.generate_recurrence_graph()
        self.generate_pretrain()
Code example #3
File: amazon_attr.py  Project: federicoBetti/RelGAN
    def __init__(self, x_real, temperature, x_user, x_product, x_rating,
                 vocab_size, batch_size, seq_len, gen_emb_dim, mem_slots,
                 head_size, num_heads, hidden_dim, start_token, user_num,
                 product_num, rating_num, **kwargs):
        self.generated_num = None
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.temperature = temperature
        self.seq_len = seq_len
        self.gen_emb_dim = gen_emb_dim
        self.x_real = x_real
        self.start_tokens = tf.constant([start_token] * batch_size,
                                        dtype=tf.int32)
        self.x_user = x_user
        self.x_rating = x_rating
        self.x_product = x_product
        output_memory_size = mem_slots * head_size * num_heads

        self.g_embeddings = tf.get_variable(
            'g_emb',
            shape=[vocab_size, gen_emb_dim],
            initializer=create_linear_initializer(vocab_size))
        self.gen_mem = RelationalMemory(mem_slots=mem_slots,
                                        head_size=head_size,
                                        num_heads=num_heads)
        self.g_output_unit = create_output_unit(output_memory_size, vocab_size)

        # attribute handling
        self.g_user = linear(input_=tf.one_hot(x_user, user_num),
                             output_size=gen_emb_dim,
                             use_bias=True,
                             scope="linear_x_user")
        self.g_product = linear(input_=tf.one_hot(x_product, product_num),
                                output_size=gen_emb_dim,
                                use_bias=True,
                                scope="linear_x_product")
        self.g_rating = linear(input_=tf.one_hot(x_rating, rating_num),
                               output_size=gen_emb_dim,
                               use_bias=True,
                               scope="linear_x_rating")
        self.g_attribute = linear(input_=tf.concat(
            [self.g_user, self.g_product, self.g_rating], axis=1),
                                  output_size=self.gen_emb_dim,
                                  use_bias=True,
                                  scope="linear_after_concat")

        # self_attention_unit = create_self_attention_unit(scope="attribute_self_attention") #todo

        # initial states
        self.init_states = self.gen_mem.initial_state(batch_size)
        self.create_recurrence()
        self.create_pretrain()
Code example #4
File: rmc_vanilla.py  Project: lethaiq/RelGAN
def generator(x_real, temperature, vocab_size, batch_size, seq_len,
              gen_emb_dim, mem_slots, head_size, num_heads, hidden_dim,
              start_token):
    start_tokens = tf.constant([start_token] * batch_size, dtype=tf.int32)
    output_size = mem_slots * head_size * num_heads

    # build the relational memory module
    g_embeddings = tf.get_variable(
        'g_emb',
        shape=[vocab_size, gen_emb_dim],
        initializer=create_linear_initializer(vocab_size))
    gen_mem = RelationalMemory(mem_slots=mem_slots,
                               head_size=head_size,
                               num_heads=num_heads)
    g_output_unit = create_output_unit(output_size, vocab_size)

    # initial states
    init_states = gen_mem.initial_state(batch_size)

    # ---------- generate tokens and approximated one-hot results (Adversarial) ---------
    gen_o = tensor_array_ops.TensorArray(dtype=tf.float32,
                                         size=seq_len,
                                         dynamic_size=False,
                                         infer_shape=True)
    gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                         size=seq_len,
                                         dynamic_size=False,
                                         infer_shape=True)
    gen_x_onehot_adv = tensor_array_ops.TensorArray(
        dtype=tf.float32, size=seq_len, dynamic_size=False,
        infer_shape=True)  # generator output (relaxed version of gen_x)

    # the generator recurrent module used for adversarial training
    def _gen_recurrence(i, x_t, h_tm1, gen_o, gen_x, gen_x_onehot_adv):
        mem_o_t, h_t = gen_mem(x_t, h_tm1)  # hidden_memory_tuple
        o_t = g_output_unit(mem_o_t)  # batch x vocab, logits not probs
        gumbel_t = add_gumbel(o_t)
        next_token = tf.stop_gradient(
            tf.argmax(gumbel_t, axis=1, output_type=tf.int32))
        next_token_onehot = tf.one_hot(next_token, vocab_size, 1.0, 0.0)

        x_onehot_appr = tf.nn.softmax(tf.multiply(
            gumbel_t, temperature))  # one-hot-like, [batch_size x vocab_size]

        # x_tp1 = tf.matmul(x_onehot_appr, g_embeddings)  # approximated embeddings, [batch_size x emb_dim]
        x_tp1 = tf.nn.embedding_lookup(
            g_embeddings, next_token)  # embeddings, [batch_size x emb_dim]

        gen_o = gen_o.write(i,
                            tf.reduce_sum(
                                tf.multiply(next_token_onehot, x_onehot_appr),
                                1))  # [batch_size], prob
        gen_x = gen_x.write(i, next_token)  # indices, [batch_size]

        gen_x_onehot_adv = gen_x_onehot_adv.write(i, x_onehot_appr)

        return i + 1, x_tp1, h_t, gen_o, gen_x, gen_x_onehot_adv

    # build a graph for outputting sequential tokens
    _, _, _, gen_o, gen_x, gen_x_onehot_adv = control_flow_ops.while_loop(
        cond=lambda i, _1, _2, _3, _4, _5: i < seq_len,
        body=_gen_recurrence,
        loop_vars=(tf.constant(0, dtype=tf.int32),
                   tf.nn.embedding_lookup(g_embeddings, start_tokens),
                   init_states, gen_o, gen_x, gen_x_onehot_adv))

    gen_o = tf.transpose(gen_o.stack(), perm=[1, 0])  # batch_size x seq_len
    gen_x = tf.transpose(gen_x.stack(), perm=[1, 0])  # batch_size x seq_len

    gen_x_onehot_adv = tf.transpose(
        gen_x_onehot_adv.stack(),
        perm=[1, 0, 2])  # batch_size x seq_len x vocab_size

    # ----------- pre-training for generator -----------------
    x_emb = tf.transpose(tf.nn.embedding_lookup(g_embeddings, x_real),
                         perm=[1, 0, 2])  # seq_len x batch_size x emb_dim
    g_predictions = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                 size=seq_len,
                                                 dynamic_size=False,
                                                 infer_shape=True)

    ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32, size=seq_len)
    ta_emb_x = ta_emb_x.unstack(x_emb)

    # the generator recurrent module used for pre-training
    def _pretrain_recurrence(i, x_t, h_tm1, g_predictions):
        mem_o_t, h_t = gen_mem(x_t, h_tm1)
        o_t = g_output_unit(mem_o_t)
        g_predictions = g_predictions.write(
            i, tf.nn.softmax(o_t))  # batch_size x vocab_size
        x_tp1 = ta_emb_x.read(i)
        return i + 1, x_tp1, h_t, g_predictions

    # build a graph for outputting sequential tokens
    _, _, _, g_predictions = control_flow_ops.while_loop(
        cond=lambda i, _1, _2, _3: i < seq_len,
        body=_pretrain_recurrence,
        loop_vars=(tf.constant(0, dtype=tf.int32),
                   tf.nn.embedding_lookup(g_embeddings, start_tokens),
                   init_states, g_predictions))

    g_predictions = tf.transpose(
        g_predictions.stack(),
        perm=[1, 0, 2])  # batch_size x seq_length x vocab_size

    # pre-training loss
    pretrain_loss = -tf.reduce_sum(
        tf.one_hot(tf.to_int32(tf.reshape(x_real, [-1])), vocab_size, 1.0, 0.0)
        * tf.log(
            tf.clip_by_value(tf.reshape(g_predictions, [-1, vocab_size]),
                             1e-20, 1.0))) / (seq_len * batch_size)

    return gen_x_onehot_adv, gen_x, pretrain_loss, gen_o
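
The adversarial loop above (_gen_recurrence) relies on the Gumbel-softmax relaxation: add_gumbel perturbs the logits o_t with Gumbel noise, the argmax of the perturbed logits gives the discrete token, and the temperature-scaled softmax gives the differentiable one-hot approximation written to gen_x_onehot_adv. The standalone NumPy sketch below illustrates that relaxation in its standard textbook form; it is not the project's add_gumbel helper, and the function name, shapes and temperature value are illustrative assumptions.

import numpy as np

def gumbel_softmax_relaxation(logits, temperature, rng):
    """Illustrative Gumbel-softmax step (standard formulation, not RelGAN's helper)."""
    # Sample Gumbel(0, 1) noise via the inverse CDF of uniform samples.
    u = rng.uniform(low=1e-20, high=1.0, size=logits.shape)
    gumbel = -np.log(-np.log(u))
    perturbed = logits + gumbel
    # Discrete token: argmax of the perturbed logits (a sample from softmax(logits)).
    next_token = perturbed.argmax(axis=1)
    # Relaxed one-hot: temperature-scaled softmax of the perturbed logits
    # (the code above multiplies by its temperature tensor in the same way).
    scaled = perturbed * temperature
    relaxed = np.exp(scaled - scaled.max(axis=1, keepdims=True))
    relaxed /= relaxed.sum(axis=1, keepdims=True)
    return next_token, relaxed

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10))  # batch_size x vocab_size (toy values)
tokens, one_hot_like = gumbel_softmax_relaxation(logits, temperature=100.0, rng=rng)
print(tokens, one_hot_like.argmax(axis=1))  # the relaxed distribution peaks at the sampled tokens
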
Code example #5
class ReviewGenerator:
    def __init__(self, x_real, temperature, x_sentiment, vocab_size,
                 batch_size, seq_len, gen_emb_dim, mem_slots, head_size,
                 num_heads, hidden_dim, start_token, sentiment_num, **kwargs):
        self.generated_num = None
        self.x_real = x_real
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.start_tokens = tf.constant([start_token] * batch_size,
                                        dtype=tf.int32)
        output_memory_size = mem_slots * head_size * num_heads
        self.temperature = temperature
        self.x_sentiment = x_sentiment

        self.g_embeddings = tf.get_variable(
            'g_emb',
            shape=[vocab_size, gen_emb_dim],
            initializer=create_linear_initializer(vocab_size))
        self.gen_mem = RelationalMemory(mem_slots=mem_slots,
                                        head_size=head_size,
                                        num_heads=num_heads)
        self.g_output_unit = create_output_unit(output_memory_size, vocab_size)

        # attribute handling
        self.g_sentiment = linear(input_=tf.one_hot(self.x_sentiment,
                                                    sentiment_num),
                                  output_size=gen_emb_dim,
                                  use_bias=True,
                                  scope="linear_x_sentiment")

        # self_attention_unit = create_self_attention_unit(scope="attribute_self_attention") #todo

        # initial states
        self.init_states = self.gen_mem.initial_state(batch_size)

        # ---------- generate tokens and approximated one-hot results (Adversarial) ---------
        self.gen_o = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                  size=seq_len,
                                                  dynamic_size=False,
                                                  infer_shape=True)
        self.gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                                  size=seq_len,
                                                  dynamic_size=False,
                                                  infer_shape=True)
        self.gen_x_onehot_adv = tensor_array_ops.TensorArray(
            dtype=tf.float32,
            size=seq_len,
            dynamic_size=False,
            infer_shape=True)
        self.pretrain_loss = None
        self.generate_recurrence_graph()
        self.generate_pretrain()

    def generate_recurrence_graph(self):
        def _gen_recurrence(i, x_t, h_tm1, gen_o, gen_x, gen_x_onehot_adv):
            mem_o_t, h_t = self.gen_mem(
                x_t, h_tm1
            )  # hidden_memory_tuple; memory output that could be reused
            mem_o_t, h_t = self.gen_mem(self.g_sentiment, h_t)
            # mem_o_t, h_t = gen_mem(self_attention_unit(), h_t) # todo
            o_t = self.g_output_unit(mem_o_t)  # batch x vocab, logits not prob

            gumbel_t = add_gumbel(o_t)
            next_token = tf.cast(tf.argmax(gumbel_t, axis=1), tf.int32)
            x_onehot_appr = tf.nn.softmax(
                tf.multiply(gumbel_t, self.temperature, name="gumbel_x_temp"),
                name="softmax_gumbel_temp"
            )  # one-hot-like, [batch_size x vocab_size]

            x_tp1 = tf.nn.embedding_lookup(
                self.g_embeddings,
                next_token)  # embeddings, [batch_size x emb_dim]
            gen_o = gen_o.write(i,
                                tf.reduce_sum(
                                    tf.multiply(
                                        tf.one_hot(next_token,
                                                   self.vocab_size, 1.0, 0.0),
                                        tf.nn.softmax(o_t)),
                                    1))  # [batch_size] , prob
            gen_x = gen_x.write(i, next_token)  # indices, [batch_size]
            gen_x_onehot_adv = gen_x_onehot_adv.write(i, x_onehot_appr)

            return i + 1, x_tp1, h_t, gen_o, gen_x, gen_x_onehot_adv

        _, _, _, gen_o, gen_x, gen_x_onehot_adv = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4, _5: i < self.seq_len,
            body=_gen_recurrence,
            loop_vars=(
                tf.constant(0, dtype=tf.int32),
                tf.nn.embedding_lookup(self.g_embeddings, self.start_tokens),
                # todo: one could consider changing the first input
                self.init_states,
                self.gen_o,
                self.gen_x,
                self.gen_x_onehot_adv),
            name="while_adv_recurrence")

        gen_x = gen_x.stack()  # seq_len x batch_size
        self.gen_x = tf.transpose(gen_x, perm=[1, 0],
                                  name="gen_x_trans")  # batch_size x seq_len

        gen_o = gen_o.stack()
        self.gen_o = tf.transpose(gen_o, perm=[1, 0], name="gen_o_trans")

        gen_x_onehot_adv = gen_x_onehot_adv.stack()
        self.gen_x_onehot_adv = tf.transpose(
            gen_x_onehot_adv, perm=[1, 0, 2],
            name="gen_x_onehot_adv_trans")  # batch_size x seq_len x vocab_size

    def generate_pretrain(self):
        # ----------- pre-training for generator -----------------
        x_emb = tf.transpose(
            tf.nn.embedding_lookup(self.g_embeddings, self.x_real),
            perm=[1, 0, 2],
            name="input_embedding")  # seq_len x batch_size x emb_dim
        g_predictions = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                     size=self.seq_len,
                                                     dynamic_size=False,
                                                     infer_shape=True)

        ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=self.seq_len)
        ta_emb_x = ta_emb_x.unstack(x_emb)

        def _pretrain_recurrence(i, x_t, h_tm1, g_predictions):
            mem_o_t, h_t = self.gen_mem(x_t, h_tm1)
            mem_o_t, h_t = self.gen_mem(self.g_sentiment, h_t)
            o_t = self.g_output_unit(mem_o_t)
            g_predictions = g_predictions.write(
                i, tf.nn.softmax(o_t))  # batch_size x vocab_size
            x_tp1 = ta_emb_x.read(i)
            return i + 1, x_tp1, h_t, g_predictions

        _, _, _, g_predictions = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3: i < self.seq_len,
            body=_pretrain_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings,
                                              self.start_tokens),
                       self.init_states, g_predictions),
            name="while_pretrain")

        g_predictions = tf.transpose(
            g_predictions.stack(), perm=[1, 0, 2],
            name="g_predictions_trans")  # batch_size x seq_length x vocab_size

        # pretraining loss
        with tf.variable_scope("pretrain_loss_computation"):
            self.pretrain_loss = -tf.reduce_sum(
                tf.one_hot(tf.cast(tf.reshape(self.x_real, [-1]), tf.int32),
                           self.vocab_size, 1.0, 0.0) *
                tf.log(
                    tf.clip_by_value(
                        tf.reshape(g_predictions, [-1, self.vocab_size]),
                        1e-20, 1.0))) / (self.seq_len * self.batch_size)

    def pretrain_epoch(self, oracle_loader, sess, **kwargs):
        supervised_g_losses = []
        for it in range(oracle_loader.num_batch):
            sentiment, sentence = oracle_loader.next_batch()
            n = np.zeros((self.batch_size, self.seq_len))
            for ind, el in enumerate(sentence):
                n[ind] = el

            if 'g_pretrain_op' in kwargs:
                _, g_loss = sess.run(
                    [kwargs['g_pretrain_op'], self.pretrain_loss],
                    feed_dict={
                        self.x_real: n,
                        self.x_sentiment: sentiment
                    })
            else:
                g_loss = sess.run(self.pretrain_loss,
                                  feed_dict={
                                      self.x_real: n,
                                      self.x_sentiment: sentiment
                                  })

            supervised_g_losses.append(g_loss)

        return np.mean(supervised_g_losses)

    def generate_json(self, oracle_loader: RealDataCustomerReviewsLoader, sess,
                      **config):
        generated_samples, input_sentiment = [], []
        sentence_generated_from = []

        max_gen = int(self.generated_num / self.batch_size)  # - 155 # 156
        for ii in range(max_gen):
            sentiment, sentences = oracle_loader.random_batch()
            feed_dict = {self.x_sentiment: sentiment}
            sentence_generated_from.extend(sentences)
            gen_x_res = sess.run([self.gen_x], feed_dict=feed_dict)

            generated_samples.extend([x for a in gen_x_res for x in a])
            input_sentiment.extend(sentiment)

        json_file = {'sentences': []}
        for sent, input_sent in zip(generated_samples, input_sentiment):
            json_file['sentences'].append({
                'generated_sentence':
                " ".join([
                    oracle_loader.model_index_word_dict[str(el)] for el in sent
                    if el < len(oracle_loader.model_index_word_dict)
                ]),
                'sentiment':
                input_sent
            })

        return json_file
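
generate_pretrain above trains by teacher forcing: the loss is the average negative log-probability that the per-step softmax assigns to the real tokens, with predictions clipped to avoid log(0). A small self-contained NumPy check of the same formula, with toy shapes chosen only for illustration:

import numpy as np

batch_size, seq_len, vocab_size = 2, 3, 5
rng = np.random.default_rng(0)

x_real = rng.integers(0, vocab_size, size=(batch_size, seq_len))  # real token indices
logits = rng.normal(size=(batch_size, seq_len, vocab_size))
g_predictions = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax per step

# One-hot of the real tokens times log of the clipped predictions, as in the TF graph above.
one_hot = np.eye(vocab_size)[x_real.reshape(-1)]
log_pred = np.log(np.clip(g_predictions.reshape(-1, vocab_size), 1e-20, 1.0))
pretrain_loss = -np.sum(one_hot * log_pred) / (seq_len * batch_size)
print(pretrain_loss)  # mean per-token cross-entropy
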
Code example #6
def generator(x_real, temperature, vocab_size, batch_size, seq_len,
              gen_emb_dim, mem_slots, head_size, num_heads, hidden_dim,
              start_token, sample_output_is_gumbel=True):
    # sample_output_is_gumbel controls how the roll-out below samples tokens
    start_tokens = tf.constant([start_token] * batch_size, dtype=tf.int32)
    output_size = mem_slots * head_size * num_heads

    # build the relational memory module
    g_embeddings = tf.get_variable(
        'g_emb',
        shape=[vocab_size, gen_emb_dim],
        initializer=create_linear_initializer(vocab_size))
    gen_mem = RelationalMemory(mem_slots=mem_slots,
                               head_size=head_size,
                               num_heads=num_heads)
    g_output_unit = create_output_unit(output_size, vocab_size)

    # initial states
    init_states = gen_mem.initial_state(batch_size)

    # ---------- generate tokens and approximated one-hot results (Adversarial) ---------
    gen_o = tensor_array_ops.TensorArray(dtype=tf.float32,
                                         size=seq_len,
                                         dynamic_size=False,
                                         infer_shape=True)
    gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                         size=seq_len,
                                         dynamic_size=False,
                                         infer_shape=True)
    gen_x_sample = tensor_array_ops.TensorArray(dtype=tf.int32,
                                                size=seq_len,
                                                dynamic_size=False,
                                                infer_shape=True)
    gen_x_onehot_adv = tensor_array_ops.TensorArray(
        dtype=tf.float32, size=seq_len, dynamic_size=False,
        infer_shape=True)  # generator output (relaxed version of gen_x)

    # the generator recurrent module used for adversarial training
    def _gen_recurrence(i, x_t, h_tm1, gen_o, gen_x, gen_x_sample,
                        gen_x_onehot_adv):
        mem_o_t, h_t = gen_mem(x_t, h_tm1)  # hidden_memory_tuple
        o_t = g_output_unit(mem_o_t)  # batch x vocab, logits not probs
        gumbel_t = add_gumbel(o_t)
        next_token = tf.stop_gradient(
            tf.argmax(gumbel_t, axis=1, output_type=tf.int32))
        next_token_sample = tf.stop_gradient(
            tf.multinomial(tf.log(tf.clip_by_value(gumbel_t, 1e-20, 1.0)),
                           1,
                           output_dtype=tf.int32))
        next_token_sample = tf.reshape(next_token_sample, [-1])
        next_token_onehot = tf.one_hot(next_token, vocab_size, 1.0, 0.0)

        x_onehot_appr = tf.nn.softmax(tf.multiply(
            gumbel_t, temperature))  # one-hot-like, [batch_size x vocab_size]

        # x_tp1 = tf.matmul(x_onehot_appr, g_embeddings)  # approximated embeddings, [batch_size x emb_dim]
        x_tp1 = tf.nn.embedding_lookup(
            g_embeddings, next_token)  # embeddings, [batch_size x emb_dim]

        gen_o = gen_o.write(i,
                            tf.reduce_sum(
                                tf.multiply(next_token_onehot, x_onehot_appr),
                                1))  # [batch_size], prob
        gen_x = gen_x.write(i, next_token)  # indices, [batch_size]
        gen_x_sample = gen_x_sample.write(
            i, next_token_sample)  # indices, [batch_size]

        gen_x_onehot_adv = gen_x_onehot_adv.write(i, x_onehot_appr)

        return i + 1, x_tp1, h_t, gen_o, gen_x, gen_x_sample, gen_x_onehot_adv

    # build a graph for outputting sequential tokens
    _, _, _, gen_o, gen_x, gen_x_sample, gen_x_onehot_adv = control_flow_ops.while_loop(
        cond=lambda i, _1, _2, _3, _4, _5, _6: i < seq_len,
        body=_gen_recurrence,
        loop_vars=(tf.constant(0, dtype=tf.int32),
                   tf.nn.embedding_lookup(g_embeddings, start_tokens),
                   init_states, gen_o, gen_x, gen_x_sample, gen_x_onehot_adv))

    gen_o = tf.transpose(gen_o.stack(), perm=[1, 0])  # batch_size x seq_len
    gen_x = tf.transpose(gen_x.stack(), perm=[1, 0])  # batch_size x seq_len
    gen_x_sample = tf.transpose(gen_x_sample.stack(),
                                perm=[1, 0])  # batch_size x seq_len

    gen_x_onehot_adv = tf.transpose(
        gen_x_onehot_adv.stack(),
        perm=[1, 0, 2])  # batch_size x seq_len x vocab_size

    # ----------- pre-training for generator -----------------
    x_emb = tf.transpose(tf.nn.embedding_lookup(g_embeddings, x_real),
                         perm=[1, 0, 2])  # seq_len x batch_size x emb_dim
    g_predictions = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                 size=seq_len,
                                                 dynamic_size=False,
                                                 infer_shape=True)

    ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32, size=seq_len)
    ta_emb_x = ta_emb_x.unstack(x_emb)

    # the generator recurrent module used for pre-training
    def _pretrain_recurrence(i, x_t, h_tm1, g_predictions):
        mem_o_t, h_t = gen_mem(x_t, h_tm1)
        o_t = g_output_unit(mem_o_t)
        g_predictions = g_predictions.write(
            i, tf.nn.softmax(o_t))  # batch_size x vocab_size
        x_tp1 = ta_emb_x.read(i)
        return i + 1, x_tp1, h_t, g_predictions

    # build a graph for outputting sequential tokens
    _, _, _, g_predictions = control_flow_ops.while_loop(
        cond=lambda i, _1, _2, _3: i < seq_len,
        body=_pretrain_recurrence,
        loop_vars=(tf.constant(0, dtype=tf.int32),
                   tf.nn.embedding_lookup(g_embeddings, start_tokens),
                   init_states, g_predictions))

    g_predictions = tf.transpose(
        g_predictions.stack(),
        perm=[1, 0, 2])  # batch_size x seq_length x vocab_size

    # pre-training loss
    pretrain_loss = -tf.reduce_sum(
        tf.one_hot(tf.to_int32(tf.reshape(x_real, [-1])), vocab_size, 1.0, 0.0)
        * tf.log(
            tf.clip_by_value(tf.reshape(g_predictions, [-1, vocab_size]),
                             1e-20, 1.0))) / (seq_len * batch_size)

    # Policy gradients tensors and computational graph ========================================================

    # initial states
    r_init_states = gen_mem.initial_state(batch_size)

    r_gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                           size=seq_len,
                                           dynamic_size=False,
                                           infer_shape=True)
    r_gen_x_sample = tensor_array_ops.TensorArray(dtype=tf.int32,
                                                  size=seq_len,
                                                  dynamic_size=False,
                                                  infer_shape=True)

    given_num_ph = tf.placeholder(tf.int32)
    r_x = tf.placeholder(
        tf.int32, shape=[batch_size, seq_len]
    )  # sequence of tokens generated by generator as actions (a) for policy gradients

    r_x_emb = tf.transpose(tf.nn.embedding_lookup(g_embeddings, r_x),
                           perm=[1, 0, 2])  # seq_len x batch_size x emb_dim
    r_ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32, size=seq_len)
    r_ta_emb_x = r_ta_emb_x.unstack(r_x_emb)

    ta_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=seq_len)
    ta_x = ta_x.unstack(tf.transpose(r_x, perm=[1, 0]))

    # When current index i < given_num, use the provided tokens as the input at each time step
    def _g_recurrence_reward_1(i, x_t, h_tm1, given_num, r_gen_x,
                               r_gen_x_sample):
        _, h_t = gen_mem(x_t, h_tm1)  # hidden_memory_tuple
        # o_t = g_output_unit(mem_o_t)  # batch x vocab, logits not probs
        x_tp1 = r_ta_emb_x.read(i)
        r_gen_x = r_gen_x.write(i, ta_x.read(i))
        r_gen_x_sample = r_gen_x_sample.write(i, ta_x.read(i))
        return i + 1, x_tp1, h_t, given_num, r_gen_x, r_gen_x_sample

    # When current index i >= given_num, start roll-out: use the output at time step t as the input at time step t+1
    def _g_recurrence_reward_2(i, x_t, h_tm1, given_num, r_gen_x,
                               r_gen_x_sample):
        mem_o_t, h_t = gen_mem(x_t, h_tm1)  # hidden_memory_tuple
        o_t = g_output_unit(mem_o_t)  # batch x vocab, logits not probs

        if sample_output_is_gumbel:
            gumbel_t = add_gumbel(o_t)
            next_token = tf.stop_gradient(
                tf.argmax(gumbel_t, axis=1, output_type=tf.int32))
            next_token_sample = tf.stop_gradient(
                tf.multinomial(tf.log(tf.clip_by_value(gumbel_t, 1e-20, 1.0)),
                               1,
                               output_dtype=tf.int32))
            next_token_sample = tf.reshape(next_token_sample, [-1])
        else:
            log_prob = tf.log(tf.nn.softmax(o_t))
            next_token = tf.cast(
                tf.reshape(tf.multinomial(log_prob, 1), [batch_size]),
                tf.int32)
            next_token_sample = tf.zeros([batch_size], dtype=tf.int32)

        # x_tp1 = tf.matmul(x_onehot_appr, g_embeddings)  # approximated embeddings, [batch_size x emb_dim]
        x_tp1 = tf.nn.embedding_lookup(
            g_embeddings, next_token)  # embeddings, [batch_size x emb_dim]
        r_gen_x = r_gen_x.write(i, next_token)  # indices, [batch_size]
        r_gen_x_sample = r_gen_x_sample.write(
            i, next_token_sample)  # indices, [batch_size]
        return i + 1, x_tp1, h_t, given_num, r_gen_x, r_gen_x_sample

    r_i, r_x_t, r_h_tm1, given_num, r_gen_x, r_gen_x_sample = control_flow_ops.while_loop(
        cond=lambda i, _1, _2, given_num, _4, _5: i < given_num,
        body=_g_recurrence_reward_1,
        loop_vars=(tf.constant(0, dtype=tf.int32),
                   tf.nn.embedding_lookup(g_embeddings, start_tokens),
                   r_init_states, given_num_ph, r_gen_x, r_gen_x_sample))

    _, _, _, _, r_gen_x, r_gen_x_sample = control_flow_ops.while_loop(
        cond=lambda i, _1, _2, _3, _4, _5: i < seq_len,
        body=_g_recurrence_reward_2,
        loop_vars=(r_i, r_x_t, r_h_tm1, given_num, r_gen_x, r_gen_x_sample))

    r_gen_x = r_gen_x.stack()  # seq_length x batch_size
    r_gen_x = tf.transpose(r_gen_x, perm=[1, 0])  # batch_size x seq_length
    r_gen_x_sample = r_gen_x_sample.stack()  # seq_length x batch_size
    r_gen_x_sample = tf.transpose(r_gen_x_sample,
                                  perm=[1, 0])  # batch_size x seq_length

    return gen_x_onehot_adv, gen_x, gen_x_sample, pretrain_loss, gen_o, given_num_ph, r_x, r_gen_x, r_gen_x_sample
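
The second half of this example builds a SeqGAN-style roll-out graph for policy gradients: for positions i < given_num the tokens supplied in r_x are fed back verbatim (_g_recurrence_reward_1), and from given_num onward the generator samples the remaining tokens itself (_g_recurrence_reward_2). Below is a compact pure-Python sketch of that control flow, using a stand-in uniform sampler instead of the relational memory; the names and the sampler are assumptions, not the project's code.

import numpy as np

def rollout(prefix_tokens, given_num, seq_len, vocab_size, rng):
    """Keep the first given_num tokens, sample the rest (stand-in for the roll-out graph)."""
    completed = list(prefix_tokens[:given_num])              # copy the provided tokens
    for _ in range(given_num, seq_len):                      # generator completes the sequence
        completed.append(int(rng.integers(0, vocab_size)))   # toy sampler in place of gen_mem/g_output_unit
    return completed

rng = np.random.default_rng(0)
full_sequence = [3, 1, 4, 1, 5, 9, 2, 6]
print(rollout(full_sequence, given_num=3, seq_len=8, vocab_size=10, rng=rng))
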
Code example #7
File: rmc_att_topic.py  Project: federicoBetti/RelGAN
class generator:
    def __init__(self,
                 x_real,
                 temperature,
                 x_topic,
                 vocab_size,
                 batch_size,
                 seq_len,
                 gen_emb_dim,
                 mem_slots,
                 head_size,
                 num_heads,
                 hidden_dim,
                 start_token,
                 use_lambda=True,
                 **kwargs):
        self.start_tokens = tf.constant([start_token] * batch_size,
                                        dtype=tf.int32)
        self.output_memory_size = mem_slots * head_size * num_heads
        self.seq_len = seq_len
        self.x_real = x_real
        self.x_topic = x_topic
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.kwargs = kwargs
        self.vocab_size = vocab_size
        self.temperature = temperature
        self.topic_in_memory = kwargs["TopicInMemory"]
        self.no_topic = kwargs["NoTopic"]
        self.gen_emb_dim = gen_emb_dim

        self.g_embeddings = tf.get_variable(
            'g_emb',
            shape=[vocab_size, gen_emb_dim],
            initializer=create_linear_initializer(vocab_size))
        self.gen_mem = RelationalMemory(mem_slots=mem_slots,
                                        head_size=head_size,
                                        num_heads=num_heads)
        self.g_output_unit = create_output_unit(self.output_memory_size,
                                                vocab_size)
        self.g_topic_embedding = create_topic_embedding_unit(
            vocab_size, gen_emb_dim)
        self.g_output_unit_lambda = create_output_unit_lambda(
            output_size=1,
            input_size=self.output_memory_size,
            additive_scope="_lambda",
            min_value=0.01)
        self.first_embedding = self.first_embedding_function

        # initial states
        self.init_states = self.gen_mem.initial_state(batch_size)
        self.create_recurrent_adv()
        self.create_pretrain()

    def create_recurrent_adv(self):
        # ---------- generate tokens and approximated one-hot results (Adversarial) ---------
        gen_o = tensor_array_ops.TensorArray(dtype=tf.float32,
                                             size=self.seq_len,
                                             dynamic_size=False,
                                             infer_shape=True)
        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                             size=self.seq_len,
                                             dynamic_size=False,
                                             infer_shape=True)
        gen_x_onehot_adv = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                        size=self.seq_len,
                                                        dynamic_size=False,
                                                        infer_shape=True)
        topicness_values = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                        size=self.seq_len,
                                                        dynamic_size=False,
                                                        infer_shape=True)
        gen_x_no_lambda = tensor_array_ops.TensorArray(dtype=tf.int32,
                                                       size=self.seq_len,
                                                       dynamic_size=False,
                                                       infer_shape=True)

        def _gen_recurrence(i, x_t, h_tm1, gen_o, gen_x, gen_x_onehot_adv,
                            lambda_values, gen_x_no_lambda):
            mem_o_t, h_t = self.gen_mem(x_t, h_tm1)  # hidden_memory_tuple
            if self.topic_in_memory and not self.no_topic:
                mem_o_t, h_t = self.gen_mem(
                    self.g_topic_embedding(self.x_topic), h_t)
            o_t = self.g_output_unit(mem_o_t)  # batch x vocab, logits not prob

            if not self.topic_in_memory and not self.no_topic:
                topic_vector = self.x_topic
                lambda_param = self.g_output_unit_lambda(mem_o_t)
                next_token_no_lambda = tf.cast(tf.argmax(o_t, axis=1),
                                               tf.int32)
                o_t = o_t + lambda_param * topic_vector
            else:
                lambda_param = tf.zeros(self.batch_size)
                next_token_no_lambda = tf.cast(tf.argmax(o_t, axis=1),
                                               tf.int32)

            gumbel_t = add_gumbel(o_t)

            next_token = tf.cast(tf.argmax(gumbel_t, axis=1), tf.int32)

            x_onehot_appr = tf.nn.softmax(
                tf.multiply(gumbel_t, self.temperature, name="gumbel_x_temp"),
                name="softmax_gumbel_temp"
            )  # one-hot-like, [batch_size x vocab_size]
            x_tp1 = tf.nn.embedding_lookup(
                self.g_embeddings,
                next_token)  # embeddings, [batch_size x emb_dim]
            gen_o = gen_o.write(i,
                                tf.reduce_sum(
                                    tf.multiply(
                                        tf.one_hot(next_token,
                                                   self.vocab_size, 1.0, 0.0),
                                        tf.nn.softmax(o_t)),
                                    1))  # [batch_size] , prob
            gen_x = gen_x.write(i, next_token)  # indices, [batch_size]
            gen_x_onehot_adv = gen_x_onehot_adv.write(i, x_onehot_appr)

            lambda_values = lambda_values.write(i, tf.squeeze(lambda_param))
            gen_x_no_lambda = gen_x_no_lambda.write(
                i, tf.squeeze(next_token_no_lambda))
            return i + 1, x_tp1, h_t, gen_o, gen_x, gen_x_onehot_adv, lambda_values, gen_x_no_lambda

        _, _, _, gen_o, gen_x, gen_x_onehot_adv, topicness_values, gen_x_no_lambda = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4, _5, _6, _7: i < self.seq_len,
            body=_gen_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32), self.first_embedding(),
                       self.init_states, gen_o, gen_x, gen_x_onehot_adv,
                       topicness_values, gen_x_no_lambda),
            name="while_adv_recurrence")

        gen_x = gen_x.stack()  # seq_len x batch_size
        self.gen_x = tf.transpose(gen_x, perm=[1, 0],
                                  name="gen_x_trans")  # batch_size x seq_len

        gen_o = gen_o.stack()
        self.gen_o = tf.transpose(gen_o, perm=[1, 0], name="gen_o_trans")

        gen_x_onehot_adv = gen_x_onehot_adv.stack()
        self.gen_x_onehot_adv = tf.transpose(
            gen_x_onehot_adv, perm=[1, 0, 2],
            name="gen_x_onehot_adv_trans")  # batch_size x seq_len x vocab_size

        topicness_values = topicness_values.stack()  # seq_len x batch_size
        self.topicness_values = tf.transpose(
            topicness_values, perm=[1, 0],
            name="lambda_values_trans")  # batch_size x seq_len

        gen_x_no_lambda = gen_x_no_lambda.stack()  # seq_len x batch_size
        self.gen_x_no_lambda = tf.transpose(
            gen_x_no_lambda, perm=[1, 0],
            name="gen_x_no_lambda_trans")  # batch_size x seq_len

    def create_pretrain(self):
        # ----------- pre-training for generator -----------------
        x_emb = tf.transpose(
            tf.nn.embedding_lookup(self.g_embeddings, self.x_real),
            perm=[1, 0, 2],
            name="input_embedding")  # seq_len x batch_size x emb_dim
        g_predictions = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                     size=self.seq_len,
                                                     dynamic_size=False,
                                                     infer_shape=True)

        ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=self.seq_len)
        ta_emb_x = ta_emb_x.unstack(x_emb)

        def _pretrain_recurrence(i, x_t, h_tm1, g_predictions):
            mem_o_t, h_t = self.gen_mem(x_t, h_tm1)
            if self.topic_in_memory and not self.no_topic:
                mem_o_t, h_t = self.gen_mem(
                    self.g_topic_embedding(self.x_topic), h_t)
            o_t = self.g_output_unit(mem_o_t)
            if not self.topic_in_memory and not self.no_topic:
                lambda_param = self.g_output_unit_lambda(mem_o_t)
                o_t = o_t + lambda_param * self.x_topic
            g_predictions = g_predictions.write(
                i, tf.nn.softmax(o_t))  # batch_size x vocab_size
            x_tp1 = ta_emb_x.read(i)
            return i + 1, x_tp1, h_t, g_predictions

        _, _, _, g_predictions = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3: i < self.seq_len,
            body=_pretrain_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32), self.first_embedding(),
                       self.init_states, g_predictions),
            name="while_pretrain")

        self.g_predictions = tf.transpose(
            g_predictions.stack(), perm=[1, 0, 2],
            name="g_predictions_trans")  # batch_size x seq_length x vocab_size

        # pretraining loss
        with tf.variable_scope("pretrain_loss_computation"):
            self.pretrain_loss = -tf.reduce_sum(
                tf.one_hot(tf.cast(tf.reshape(self.x_real, [-1]), tf.int32),
                           self.vocab_size, 1.0, 0.0) *
                tf.log(
                    tf.clip_by_value(
                        tf.reshape(self.g_predictions, [-1, self.vocab_size]),
                        1e-20, 1.0))) / (self.seq_len * self.batch_size)

    def pretrain_epoch(self, sess, oracle_loader, **kwargs):
        supervised_g_losses = []
        oracle_loader.reset_pointer()

        for it in tqdm(range(oracle_loader.num_batch)):
            text_batch, topic_batch = oracle_loader.next_batch(only_text=False)
            _, g_loss = sess.run([self.g_pretrain_op, self.pretrain_loss],
                                 feed_dict={
                                     self.x_real: text_batch,
                                     self.x_topic: topic_batch
                                 })
            supervised_g_losses.append(g_loss)

        return np.mean(supervised_g_losses)

    def generate_samples_topic(self, sess, oracle_loader, generated_num):
        generated_samples = []
        generated_samples_lambda = []
        sentence_generated_from = []
        generated_samples_no_lambda_words = []

        max_gen = int(generated_num / self.batch_size)  # - 155 # 156
        for ii in range(max_gen):
            if self.no_topic:
                gen_x_res = sess.run(self.gen_x)
                text_batch = oracle_loader.random_batch(only_text=True)
                sentence_generated_from.extend(text_batch)
            else:
                text_batch, topic_batch = oracle_loader.random_batch(
                    only_text=False)
                feed = {self.x_topic: topic_batch}
                sentence_generated_from.extend(text_batch)
                if self.topic_in_memory:
                    gen_x_res = sess.run(self.gen_x, feed_dict=feed)
                else:
                    gen_x_res, lambda_values_res, gen_x_no_lambda_res = sess.run(
                        [
                            self.gen_x, self.topicness_values,
                            self.gen_x_no_lambda
                        ],
                        feed_dict=feed)
                    generated_samples_lambda.extend(lambda_values_res)
                    generated_samples_no_lambda_words.extend(
                        gen_x_no_lambda_res)

            generated_samples.extend(gen_x_res)

        codes = ""
        codes_with_lambda = ""
        json_file = {'sentences': []}
        if self.no_topic or self.topic_in_memory:
            for sent, start_sentence in zip(generated_samples,
                                            sentence_generated_from):
                json_file['sentences'].append({
                    'real_starting':
                    get_sentence_from_index(
                        start_sentence, oracle_loader.model_index_word_dict),
                    'generated_sentence':
                    get_sentence_from_index(
                        sent, oracle_loader.model_index_word_dict)
                })
        else:
            for sent, lambda_value_sent, no_lambda_words, start_sentence in zip(
                    generated_samples, generated_samples_lambda,
                    generated_samples_no_lambda_words,
                    sentence_generated_from):
                sent_json = []
                for x, y, z in zip(sent, lambda_value_sent, no_lambda_words):
                    sent_json.append({
                        'word_code':
                        int(x),
                        'word_text':
                        '' if x == len(oracle_loader.model_index_word_dict)
                        else oracle_loader.model_index_word_dict[str(x)],
                        'lambda':
                        float(y),
                        'no_lambda_word':
                        '' if z == len(oracle_loader.model_index_word_dict)
                        else oracle_loader.model_index_word_dict[str(z)]
                    })
                    codes_with_lambda += "{} ({:.4f};{}) ".format(x, y, z)
                    codes += "{} ".format(x)
                json_file['sentences'].append({
                    'generated':
                    sent_json,
                    'real_starting':
                    get_sentence_from_index(
                        start_sentence, oracle_loader.model_index_word_dict),
                    'generated_sentence':
                    get_sentence_from_index(
                        sent, oracle_loader.model_index_word_dict)
                })

        return json_file

    def get_sentences(self, json_object):
        sentences = json_object['sentences']
        sent_number = 10
        sent = random.sample(sentences, sent_number)
        all_sentences = []
        for s in sent:
            if self.no_topic:
                all_sentences.append("{}".format(s['generated_sentence']))
            else:
                if self.topic_in_memory:
                    all_sentences.append("{} --- {}".format(
                        str(s['generated_sentence']), s['real_starting']))
                else:
                    word_with_no_lambda = []
                    for letter in s['generated']:
                        generated_word, real_word = letter[
                            'word_text'], letter['no_lambda_word']
                        if generated_word:
                            word_with_no_lambda.append("{} ({}, {})".format(
                                generated_word, letter['lambda'], real_word))
                    word_with_no_lambda = " ".join(word_with_no_lambda)
                    all_sentences.append("{} ---- {} ---- {}".format(
                        s['generated_sentence'], word_with_no_lambda,
                        s['real_starting']))
        return all_sentences

    def first_embedding_function(self):
        return tf.nn.embedding_lookup(self.g_embeddings, self.start_tokens)
        # alternative first embedding (disabled): a random vector instead of the start-token lookup
        # return tf.random.uniform([self.batch_size, self.gen_emb_dim])
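
When the topic is not injected into the memory (and topics are enabled), _gen_recurrence above biases the vocabulary logits toward the topic: a learned per-batch lambda (kept above 0.01 by create_output_unit_lambda) scales x_topic, which is added to o_t before the Gumbel step. A toy NumPy illustration of that biasing; shapes and values are assumptions made only for the example.

import numpy as np

rng = np.random.default_rng(0)
batch_size, vocab_size = 2, 6
o_t = rng.normal(size=(batch_size, vocab_size))               # token logits from the output unit
x_topic = np.abs(rng.normal(size=(batch_size, vocab_size)))   # topic weights over the vocabulary
lambda_param = np.full((batch_size, 1), 0.5)                  # "topicness" gate, >= 0.01 in the model

biased_logits = o_t + lambda_param * x_topic                  # same combination as in _gen_recurrence
print(o_t.argmax(axis=1), biased_logits.argmax(axis=1))       # the topic bias can change the argmax token
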
Code example #8
File: amazon_attr.py  Project: federicoBetti/RelGAN
class AmazonGenerator:
    def __init__(self, x_real, temperature, x_user, x_product, x_rating,
                 vocab_size, batch_size, seq_len, gen_emb_dim, mem_slots,
                 head_size, num_heads, hidden_dim, start_token, user_num,
                 product_num, rating_num, **kwargs):
        self.generated_num = None
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.temperature = temperature
        self.seq_len = seq_len
        self.gen_emb_dim = gen_emb_dim
        self.x_real = x_real
        self.start_tokens = tf.constant([start_token] * batch_size,
                                        dtype=tf.int32)
        self.x_user = x_user
        self.x_rating = x_rating
        self.x_product = x_product
        output_memory_size = mem_slots * head_size * num_heads

        self.g_embeddings = tf.get_variable(
            'g_emb',
            shape=[vocab_size, gen_emb_dim],
            initializer=create_linear_initializer(vocab_size))
        self.gen_mem = RelationalMemory(mem_slots=mem_slots,
                                        head_size=head_size,
                                        num_heads=num_heads)
        self.g_output_unit = create_output_unit(output_memory_size, vocab_size)

        # attribute handling
        self.g_user = linear(input_=tf.one_hot(x_user, user_num),
                             output_size=gen_emb_dim,
                             use_bias=True,
                             scope="linear_x_user")
        self.g_product = linear(input_=tf.one_hot(x_product, product_num),
                                output_size=gen_emb_dim,
                                use_bias=True,
                                scope="linear_x_product")
        self.g_rating = linear(input_=tf.one_hot(x_rating, rating_num),
                               output_size=gen_emb_dim,
                               use_bias=True,
                               scope="linear_x_rating")
        self.g_attribute = linear(input_=tf.concat(
            [self.g_user, self.g_product, self.g_rating], axis=1),
                                  output_size=self.gen_emb_dim,
                                  use_bias=True,
                                  scope="linear_after_concat")

        # self_attention_unit = create_self_attention_unit(scope="attribute_self_attention") #todo

        # initial states
        self.init_states = self.gen_mem.initial_state(batch_size)
        self.create_recurrence()
        self.create_pretrain()

    def multihead_attention(self, attribute):
        """Perform multi-head attention from 'Attention is All You Need'.

        Implementation of the attention mechanism from
        https://arxiv.org/abs/1706.03762.

        Args:
          memory: Memory tensor to perform attention on, with size [B, N, H*V].

        Returns:
          new_memory: New memory tensor.
        """
        key_size = 512
        head_size = 512
        num_heads = 2
        qkv_size = 2 * key_size + head_size
        total_size = qkv_size * num_heads  # Denote as F.
        batch_size = attribute.get_shape().as_list()[0]  # Denote as B
        qkv = linear(attribute, total_size, use_bias=False,
                     scope='lin_qkv')  # [B*N, F]
        qkv = tf.reshape(qkv, [batch_size, -1, total_size])  # [B, N, F]
        qkv = tf.contrib.layers.layer_norm(qkv, trainable=True)  # [B, N, F]

        # [B, N, F] -> [B, N, H, F/H]
        qkv_reshape = tf.reshape(qkv, [batch_size, -1, num_heads, qkv_size])

        # [B, N, H, F/H] -> [B, H, N, F/H]
        qkv_transpose = tf.transpose(qkv_reshape, [0, 2, 1, 3])
        q, k, v = tf.split(qkv_transpose, [key_size, key_size, head_size], -1)

        q *= qkv_size**-0.5
        dot_product = tf.matmul(q, k, transpose_b=True)  # [B, H, N, N]
        weights = tf.nn.softmax(dot_product)

        output = tf.matmul(weights, v)  # [B, H, N, V]

        # [B, H, N, V] -> [B, N, H, V]
        output_transpose = tf.transpose(output, [0, 2, 1, 3])

        # [B, N, H, V] -> [B, N, H * V]
        attended_attribute = tf.reshape(output_transpose, [batch_size, -1])
        return attended_attribute

    def create_recurrence(self):
        # ---------- generate tokens and approximated one-hot results (Adversarial) ---------
        gen_o = tensor_array_ops.TensorArray(dtype=tf.float32,
                                             size=self.seq_len,
                                             dynamic_size=False,
                                             infer_shape=True)
        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                             size=self.seq_len,
                                             dynamic_size=False,
                                             infer_shape=True)
        gen_x_onehot_adv = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                        size=self.seq_len,
                                                        dynamic_size=False,
                                                        infer_shape=True)

        def _gen_recurrence(i, x_t, h_tm1, gen_o, gen_x, gen_x_onehot_adv):
            mem_o_t, h_t = self.gen_mem(x_t, h_tm1)  # hidden_memory_tuple
            mem_o_t, h_t = self.gen_mem(self.g_attribute, h_t)
            # mem_o_t, h_t = gen_mem(self_attention_unit(), h_t) # todo
            o_t = self.g_output_unit(mem_o_t)  # batch x vocab, logits not prob

            # print_op = tf.print("o_t shape", o_t.shape, ", o_t: ", o_t[0], output_stream=sys.stderr)

            gumbel_t = add_gumbel(o_t)
            next_token = tf.cast(tf.argmax(gumbel_t, axis=1), tf.int32)
            x_onehot_appr = tf.nn.softmax(
                tf.multiply(gumbel_t, self.temperature, name="gumbel_x_temp"),
                name="softmax_gumbel_temp"
            )  # one-hot-like, [batch_size x vocab_size]

            x_tp1 = tf.nn.embedding_lookup(
                self.g_embeddings,
                next_token)  # embeddings, [batch_size x emb_dim]
            gen_o = gen_o.write(i,
                                tf.reduce_sum(
                                    tf.multiply(
                                        tf.one_hot(next_token,
                                                   self.vocab_size, 1.0, 0.0),
                                        tf.nn.softmax(o_t)),
                                    1))  # [batch_size] , prob
            gen_x = gen_x.write(i, next_token)  # indices, [batch_size]
            gen_x_onehot_adv = gen_x_onehot_adv.write(i, x_onehot_appr)

            return i + 1, x_tp1, h_t, gen_o, gen_x, gen_x_onehot_adv

        _, _, _, gen_o, gen_x, gen_x_onehot_adv = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4, _5: i < self.seq_len,
            body=_gen_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings,
                                              self.start_tokens),
                       self.init_states, gen_o, gen_x, gen_x_onehot_adv),
            name="while_adv_recurrence")

        gen_x = gen_x.stack()  # seq_len x batch_size
        self.gen_x = tf.transpose(gen_x, perm=[1, 0],
                                  name="gen_x_trans")  # batch_size x seq_len

        gen_o = gen_o.stack()
        self.gen_o = tf.transpose(gen_o, perm=[1, 0], name="gen_o_trans")

        gen_x_onehot_adv = gen_x_onehot_adv.stack()
        self.gen_x_onehot_adv = tf.transpose(
            gen_x_onehot_adv, perm=[1, 0, 2],
            name="gen_x_onehot_adv_trans")  # batch_size x seq_len x vocab_size

    def create_pretrain(self):
        # ----------- pre-training for generator -----------------
        x_emb = tf.transpose(
            tf.nn.embedding_lookup(self.g_embeddings, self.x_real),
            perm=[1, 0, 2],
            name="input_embedding")  # seq_len x batch_size x emb_dim
        g_predictions = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                     size=self.seq_len,
                                                     dynamic_size=False,
                                                     infer_shape=True)

        ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=self.seq_len)
        ta_emb_x = ta_emb_x.unstack(x_emb)

        def _pretrain_recurrence(i, x_t, h_tm1, g_predictions):
            mem_o_t, h_t = self.gen_mem(x_t, h_tm1)
            mem_o_t, h_t = self.gen_mem(self.g_attribute, h_t)
            o_t = self.g_output_unit(mem_o_t)
            g_predictions = g_predictions.write(
                i, tf.nn.softmax(o_t))  # batch_size x vocab_size
            x_tp1 = ta_emb_x.read(i)
            return i + 1, x_tp1, h_t, g_predictions

        _, _, _, g_predictions = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3: i < self.seq_len,
            body=_pretrain_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings,
                                              self.start_tokens),
                       self.init_states, g_predictions),
            name="while_pretrain")

        g_predictions = tf.transpose(
            g_predictions.stack(), perm=[1, 0, 2],
            name="g_predictions_trans")  # batch_size x seq_length x vocab_size

        # pretraining loss
        with tf.variable_scope("pretrain_loss_computation"):
            self.pretrain_loss = -tf.reduce_sum(
                tf.one_hot(tf.cast(tf.reshape(self.x_real, [-1]), tf.int32),
                           self.vocab_size, 1.0, 0.0) *
                tf.log(
                    tf.clip_by_value(
                        tf.reshape(g_predictions, [-1, self.vocab_size]),
                        1e-20, 1.0))) / (self.seq_len * self.batch_size)

    def pretrain_epoch(self, oracle_loader, sess, **kwargs):
        supervised_g_losses = []
        oracle_loader.reset_pointer()

        n = np.zeros((self.batch_size, self.seq_len))
        for it in tqdm(range(oracle_loader.num_batch)):
            # t = time.time()
            user, product, rating, sentence = oracle_loader.next_batch()
            # t1 = time.time()
            for ind, el in enumerate(sentence):
                n[ind] = el
            # t2 = time.time()
            _, g_loss = sess.run(
                [kwargs['g_pretrain_op'], self.pretrain_loss],
                feed_dict={
                    self.x_real: n,
                    self.x_user: user,
                    self.x_product: product,
                    self.x_rating: rating
                })
            # t3 = time.time()
            # print("Loader {}".format(t1 - t))
            # print("n: {}".format(t2 -t1))
            # print("pretrain: {}".format(t3 - t2))
            supervised_g_losses.append(g_loss)

        return np.mean(supervised_g_losses)

    def generate_samples(self, sess, oracle_loader, **tensors):
        generated_samples = []
        sentence_generated_from = []

        max_gen = int(self.generated_num / self.batch_size)  # - 155 # 156
        for ii in range(max_gen):
            user, product, rating, sentences = oracle_loader.random_batch(
                dataset=tensors['dataset'])
            feed_dict = {
                self.x_user: user,
                self.x_product: product,
                self.x_rating: rating
            }
            sentence_generated_from.extend(sentences)
            gen_x_res = sess.run([self.gen_x], feed_dict=feed_dict)

            generated_samples.extend([x for a in gen_x_res for x in a])

        json_file = {'sentences': []}
        for sent, start_sentence in zip(generated_samples,
                                        sentence_generated_from):
            json_file['sentences'].append({
                'real_starting':
                " ".join([
                    oracle_loader.model_index_word_dict[str(el)]
                    for el in start_sentence
                    if el < len(oracle_loader.model_index_word_dict)
                ]),
                'generated_sentence':
                " ".join([
                    oracle_loader.model_index_word_dict[str(el)] for el in sent
                    if el < len(oracle_loader.model_index_word_dict)
                ])
            })

        return json_file
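
The multihead_attention helper in this example follows the scaled dot-product attention of Vaswani et al. (https://arxiv.org/abs/1706.03762): queries, keys and values come from a single linear projection, queries are scaled, and softmax(QK^T)V gives the attended output. Below is a single-head NumPy sketch of that computation; it scales by the key size as in the paper, which may differ in detail from the qkv_size scaling used above, and all names and shapes are illustrative.

import numpy as np

def scaled_dot_product_attention(q, k, v):
    """softmax(QK^T / sqrt(d_k)) V for one head; q, k: [N, d_k], v: [N, d_v]."""
    scores = q @ k.T / np.sqrt(k.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
n, d_k, d_v = 3, 4, 6                                  # toy sizes: 3 attribute slots
q, k = rng.normal(size=(n, d_k)), rng.normal(size=(n, d_k))
v = rng.normal(size=(n, d_v))
print(scaled_dot_product_attention(q, k, v).shape)     # (3, 6)
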