Code example #1
 def __init__(self,
              inputs,
              vocab,
              gamma,
              lambda_g_hidden,
              lambda_g_sentence,
              hparams=None):
     self._hparams = tx.HParams(hparams, None)
     self._prepare_inputs(inputs, vocab, gamma, lambda_g_hidden,
                          lambda_g_sentence)
     self._build_model()
Code example #2
 def __init__(self,
              inputs,
              vocab,
              gamma,
              lambda_g,
              want_code,
              hparams=None):
     self._hparams = tx.HParams(hparams, None)
     self._want_code = want_code
     print(inputs)
     self._build_model(inputs, vocab, gamma, lambda_g)
Code example #3
 def __init__(self,
              inputs,
              vocab,
              gamma,
              lambda_t_graph,
              lambda_t_sentence,
              ablation,
              hparams=None):
     self._hparams = tx.HParams(hparams, None)
     self._prepare_inputs(inputs, vocab, gamma, lambda_t_graph,
                          lambda_t_sentence, ablation)
     self._build_model()
Code example #4
 def __init__(self,
              inputs,
              vocab,
              gamma,
              lambda_g,
              lambda_z1,
              lambda_z2,
              lambda_ae,
              hparams=None):
     self._hparams = tx.HParams(hparams, None)
     self._build_model(inputs, vocab, gamma, lambda_g, lambda_z1, lambda_z2,
                       lambda_ae)
Code example #5
 def __init__(self, inputs, vocab, ctx_maxSeqLen, lr, hparams=None):
     self._hparams = tx.HParams(hparams, None)
     self._prepare_inputs(inputs, vocab, ctx_maxSeqLen, lr)
     self._build_model()
Code example #6
 def __init__(self, inputs, finputs, minputs, vocab, gamma, hparams=None):
     self._hparams = tx.HParams(hparams, None)
     self._build_model(inputs, vocab, finputs, minputs, gamma)
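
All six constructors above share one pattern: wrap the caller-supplied hyperparameters in tx.HParams (with None passed as the defaults) and then build the TensorFlow graph, either directly in _build_model or via an intermediate _prepare_inputs step. Below is a minimal sketch of that pattern; the StyleTransferModel class name, the example hyperparameter dict, and the empty _build_model body are illustrative assumptions, not code taken from any of the examples.

import texar as tx

class StyleTransferModel(object):
    # Hypothetical skeleton following the constructor pattern shown above.
    def __init__(self, inputs, vocab, gamma, lambda_g, hparams=None):
        # Wrap the raw hyperparameter dict; passing None as the defaults
        # keeps the given values as-is.
        self._hparams = tx.HParams(hparams, None)
        self._build_model(inputs, vocab, gamma, lambda_g)

    def _build_model(self, inputs, vocab, gamma, lambda_g):
        # Graph construction (embedders, encoder, decoder, classifier and
        # the gamma/lambda-weighted losses) would go here.
        pass

# Example instantiation (in actual use, inputs would be a data batch and
# vocab a texar vocabulary object):
# model = StyleTransferModel(inputs, vocab, gamma=1.0, lambda_g=0.1,
#                            hparams={'name': 'style_transfer'})
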
Code example #7
from tensorflow.python.ops import tensor_array_ops, control_flow_ops
from utils.ops import *
import texar as tx
from Gan_architecture import lstm_config
from texar.modules import UnidirectionalRNNEncoder, MLPTransformConnector, AttentionRNNDecoder, \
    Conv1DClassifier, GumbelSoftmaxEmbeddingHelper
from texar.utils import transformer_utils
import numpy as np
import tensorflow as tf

hparams = tx.HParams(lstm_config.model, None)


# The generator network based on the Relational Memory
def generator(text_ids, text_keyword_id, text_keyword_length, labels,
              text_length, temperature, vocab_size, batch_size, seq_len,
              gen_emb_dim, mem_slots, head_size, num_heads, hidden_dim,
              start_token):

    hparams = tx.HParams(lstm_config.model, None)
    # Source word embedding
    src_word_embedder = tx.modules.WordEmbedder(vocab_size=vocab_size,
                                                hparams=hparams.embedder)
    src_word_embeds = src_word_embedder(text_keyword_id)

    encoder = UnidirectionalRNNEncoder(hparams=hparams.encoder)
    enc_outputs, final_state = encoder(inputs=src_word_embeds,
                                       sequence_length=text_keyword_length)

    # modify sentiment label
    label_connector = MLPTransformConnector(output_size=hparams.dim_c)
    state_connector = MLPTransformConnector(output_size=700)

    labels = tf.to_float(tf.reshape(labels, [batch_size, 1]))
    c = label_connector(labels)
    c_ = label_connector(1 - labels)
    h = tf.concat([c, final_state], axis=1)
    h_ = tf.concat([c_, final_state], axis=1)

    state = state_connector(h)
    state_ = state_connector(h_)

    decoder = AttentionRNNDecoder(
        memory=enc_outputs,
        memory_sequence_length=text_keyword_length,
        cell_input_fn=lambda inputs, attention: inputs,
        vocab_size=vocab_size,
        hparams=hparams.decoder)

    # For training
    g_outputs, _, _ = decoder(initial_state=state,
                              inputs=text_ids,
                              embedding=src_word_embedder,
                              sequence_length=text_length - 1)

    start_tokens = np.ones(batch_size, int)
    end_token = int(2)
    # Greedy decoding, used in eval
    outputs_, _, length_ = decoder(decoding_strategy='infer_greedy',
                                   initial_state=state_,
                                   embedding=src_word_embedder,
                                   start_tokens=start_tokens,
                                   end_token=end_token)

    pretrain_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
        labels=text_ids[:, 1:],
        logits=g_outputs.logits,
        sequence_length=text_length - 1,
        average_across_timesteps=True,
        sum_over_timesteps=False)

    # Gumbel-softmax decoding, used in training
    gumbel_helper = GumbelSoftmaxEmbeddingHelper(src_word_embedder.embedding,
                                                 start_tokens, end_token,
                                                 temperature)

    gumbel_outputs, _, sequence_lengths = decoder(helper=gumbel_helper,
                                                  initial_state=state_)

    # max_index = tf.argmax(gumbel_outputs.logits, axis=2)
    # gen_x_onehot_adv = tf.one_hot(max_index, vocab_size, 1.0, 0.0)

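    # Scalar score: sum over batch and time of the per-step maximum logits
    # of the greedy decode.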
    gen_o = tf.reduce_sum(tf.reduce_max(outputs_.logits, axis=2))

    return gumbel_outputs.logits, outputs_.sample_id, pretrain_loss, gen_o
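
For context, the snippet below sketches how generator might be wired into a TF1 graph. All placeholder names, shapes, sizes and keyword values are assumptions for illustration; in the original project they come from lstm_config and the training pipeline. Note that gen_emb_dim, mem_slots, head_size, num_heads, hidden_dim, seq_len and start_token are accepted by the signature but not used in the body above, which takes its sizes from hparams instead.

import tensorflow as tf

# Assumed sizes, for illustration only.
batch_size, seq_len, vocab_size = 64, 20, 10000

text_ids = tf.placeholder(tf.int32, [batch_size, seq_len], name='text_ids')
text_keyword_id = tf.placeholder(tf.int32, [batch_size, None], name='text_keyword_id')
text_keyword_length = tf.placeholder(tf.int32, [batch_size], name='text_keyword_length')
labels = tf.placeholder(tf.int32, [batch_size], name='labels')
text_length = tf.placeholder(tf.int32, [batch_size], name='text_length')
temperature = tf.placeholder(tf.float32, [], name='temperature')

gen_logits, greedy_sample_id, pretrain_loss, gen_o = generator(
    text_ids, text_keyword_id, text_keyword_length, labels,
    text_length, temperature, vocab_size, batch_size, seq_len,
    gen_emb_dim=128, mem_slots=1, head_size=256, num_heads=2,
    hidden_dim=256, start_token=1)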