# Imports assumed by the snippets below (TensorFlow 1.x API; the exact
# module paths are assumptions, and ContextWrapper is project-local):
import tensorflow as tf
from tensorflow.contrib import distributions as tfpd
from tensorflow.contrib.rnn import (BasicLSTMCell, GRUCell, MultiRNNCell,
                                    DropoutWrapper, DeviceWrapper)
from tensorflow.contrib.seq2seq import (SampleEmbeddingHelper, TrainingHelper,
                                        BasicDecoder, BahdanauAttention,
                                        AttentionWrapper, dynamic_decode)
from tensorflow.python.layers.core import Dense


# Methods of a SampleEmbeddingHelper subclass that mixes ground-truth
# feeding, self-feeding, and reward-based sampling at each decode step.
def __init__(self, embedding, start_tokens, end_token, vocab,
             reward_metric, ground_truth, ground_truth_length, lambdas):
    SampleEmbeddingHelper.__init__(self, embedding, start_tokens, end_token)
    self._vocab = vocab
    self._ground_truth = ground_truth
    self._lambdas = lambdas  # strategy probabilities: [self, truth, reward]
    self._ground_truth_length = ground_truth_length
    self._metric = reward_metric
def sample(self, time, outputs, state, name=None):
    """Sample token ids for the next step.

    Note the special form of `state`: [decoded_ids, rnn_state].
    """
    # Draw which feeding strategy to use this step; self._lambdas gives
    # the probability of each strategy.
    sample_method_sampler = tfpd.Categorical(probs=self._lambdas)
    sample_method_id = sample_method_sampler.sample()

    # Strategy 1: feed the ground-truth token (teacher forcing); past the
    # end of the ground truth, feed EOS instead.
    truth_feeding = lambda: tf.cond(
        tf.less(time, tf.shape(self._ground_truth)[1]),
        lambda: tf.cast(self._ground_truth[:, time], tf.int32),
        lambda: tf.ones_like(self._ground_truth[:, 0],
                             dtype=tf.int32) * self._vocab.eos_token_id)

    # Strategy 0 (default): sample from the model's own distribution.
    self_feeding = lambda: SampleEmbeddingHelper.sample(
        self, time, outputs, state, name)

    # Strategy 2: sample according to the reward metric.
    reward_feeding = lambda: self._sample_by_reward(time, state)

    # The first step is always teacher-forced; afterwards, dispatch on
    # the sampled strategy id.
    sample_ids = tf.cond(
        tf.logical_or(tf.equal(time, 0), tf.equal(sample_method_id, 1)),
        truth_feeding,
        lambda: tf.cond(tf.equal(sample_method_id, 2),
                        reward_feeding, self_feeding))
    return sample_ids
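# A minimal, self-contained illustration (not from the original code) of the
# three-way dispatch pattern used in sample() above: draw a strategy id from
# a Categorical over the lambdas, then select a branch with nested tf.cond.
# The probabilities and string payloads below are toy values.
def _demo_strategy_dispatch():
    lambdas = tf.constant([0.6, 0.3, 0.1])  # P(self, truth, reward)
    method_id = tfpd.Categorical(probs=lambdas).sample()
    return tf.cond(
        tf.equal(method_id, 1),
        lambda: tf.constant('truth'),
        lambda: tf.cond(tf.equal(method_id, 2),
                        lambda: tf.constant('reward'),
                        lambda: tf.constant('self')))

# with tf.Session() as sess:
#     print(sess.run(_demo_strategy_dispatch()))  # b'self'/b'truth'/b'reward'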
def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """Note the special form of `state`: [decoded_ids, rnn_state]."""
    finished, next_inputs, next_state = SampleEmbeddingHelper.next_inputs(
        self, time, outputs, state[1], sample_ids, name)
    # Record the ids sampled at this step in position `time` of the
    # decoded-ids buffer that rides along with the RNN state.
    next_state = [tf.concat([state[0][:, :time],
                             tf.expand_dims(sample_ids, 1),
                             state[0][:, time + 1:]], axis=1),
                  next_state]
    # Restore the buffer's static width (60 here is evidently the fixed
    # maximum decode length), which the concat above leaves dynamic.
    next_state[0] = tf.reshape(next_state[0],
                               (tf.shape(sample_ids)[0], 60))
    return finished, next_inputs, next_state
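# Hypothetical usage sketch (not from the original code): the helper's class
# name is not shown above, so RewardSampleHelper is a stand-in, and the vocab
# object and reward callable are assumed. Because the helper expects its
# state in the form [decoded_ids, rnn_state], the initial state pairs a
# fixed-width id buffer with the cell's zero state; a plain RNN cell cannot
# thread that composite state, so this presumes a cell or decoder adapted
# to pass it through unchanged.
def build_reward_sampling_decoder(embedding, cell, dense, vocab, reward,
                                  ground_truth, ground_truth_lengths,
                                  batch_size, max_len=60):
    helper = RewardSampleHelper(  # stand-in name for the class above
        embedding, tf.fill([batch_size], vocab.bos_token_id),
        vocab.eos_token_id, vocab, reward, ground_truth,
        ground_truth_lengths, lambdas=[0.5, 0.3, 0.2])
    init_state = [tf.zeros([batch_size, max_len], tf.int32),  # id buffer
                  cell.zero_state(batch_size, tf.float32)]
    decoder = BasicDecoder(cell, helper, init_state, output_layer=dense)
    return dynamic_decode(decoder, maximum_iterations=max_len,
                          swap_memory=True)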
def _init(self, sequence, targets, authors):
    batch_size = tf.shape(sequence)[0]
    # Nonzero entries are real tokens; zeros are padding.
    sequence_lengths = tf.cast(tf.count_nonzero(sequence, axis=1), tf.int32)

    # Character and author-context embedding tables.
    embedding = tf.Variable(
        tf.random_normal((self._vocab_size, self._embed_size)),
        name='char_embedding')
    context = tf.Variable(
        tf.random_normal((self._author_size, self._ctx_size)),
        name='ctx_embedding')
    embedded_sequence = tf.nn.embedding_lookup(embedding, sequence)
    embedded_authors = tf.nn.embedding_lookup(context, authors)

    # Round-robin device placement across the available GPUs.
    gpu = lambda x: '/gpu:{}'.format(x % self._num_gpu)

    if self._training:
        dropout = lambda x: DropoutWrapper(
            x, 1.0 - self._input_dropout, 1.0 - self._output_dropout)
        helper = TrainingHelper(embedded_sequence, sequence_lengths)
    else:
        dropout = lambda x: x
        # At inference time, sample from the model's own outputs; decoding
        # starts from the first input token, and 2 is the end token id.
        helper = SampleEmbeddingHelper(embedding, sequence[:, 0], 2)

    # Each layer: a context-conditioned cell, optionally wrapped in
    # dropout, pinned to a GPU.
    base = lambda x: ContextWrapper(self._cell(x), embedded_authors)
    wrap = lambda i, cell: DeviceWrapper(dropout(cell), gpu(i))
    cells = [wrap(i, base(self._cell_size)) for i in range(self._cell_num)]
    cell = MultiRNNCell(cells)
    init_state = cell.zero_state(batch_size, tf.float32)

    dense = tf.layers.Dense(self._vocab_size, self._activation,
                            name='fully_connected')
    decoder = BasicDecoder(cell, helper, init_state, dense)
    output, _, _ = dynamic_decode(decoder, swap_memory=True)
    logits = output.rnn_output

    # Mask the loss so padded positions do not contribute.
    weights = tf.sequence_mask(sequence_lengths, dtype=tf.float32)
    loss = tf.contrib.seq2seq.sequence_loss(logits, targets, weights)
    out = output.sample_id
    return targets, loss, out
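# Quick illustration (toy values, not from the original code) of the loss
# masking above: tf.sequence_mask turns lengths into 0/1 weights so that
# padded positions contribute nothing to sequence_loss.
def _demo_sequence_mask():
    lengths = tf.constant([2, 3])
    return tf.sequence_mask(lengths, maxlen=4, dtype=tf.float32)

# with tf.Session() as sess:
#     print(sess.run(_demo_sequence_mask()))
#     # [[1. 1. 0. 0.]
#     #  [1. 1. 1. 0.]]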
def _init(self):
    sequence = tf.placeholder(tf.int32, [None, None], name='sequence')
    targets = tf.placeholder(tf.int32, [None, None], name='targets')
    authors = tf.placeholder(tf.int32, [None, None], name='authors')
    batch_size = tf.shape(sequence)[0]
    sequence_lengths = tf.cast(tf.count_nonzero(sequence, axis=1), tf.int32)

    embedding = tf.Variable(
        tf.random_normal((self._vocab_size, self._embed_size)))
    context = tf.Variable(
        tf.random_normal((self._author_size, self._ctx_size)))
    embedded_sequence = tf.nn.embedding_lookup(embedding, sequence)
    embedded_authors = tf.nn.embedding_lookup(context, authors)
    one_hot_targets = tf.one_hot(targets, self._vocab_size)

    gpu = lambda x: str(x % self._num_gpu)

    # Optional Bahdanau attention over the embedded input sequence.
    if self._attn:
        mech = BahdanauAttention(self._attn_depth, embedded_sequence,
                                 sequence_lengths)
        attn_cell = lambda x: DeviceWrapper(
            AttentionWrapper(x, mech, self._attn_size), "/gpu:" + gpu(1))
    else:
        attn_cell = lambda x: x

    if self._training:
        dropout = lambda x: DropoutWrapper(x, 1.0, 1.0 - self._dropout)
    else:
        dropout = lambda x: x

    if self._cell == 'lstm':
        base_cell = lambda x: dropout(BasicLSTMCell(x))
    elif self._cell == 'gru':
        base_cell = lambda x: dropout(GRUCell(x))

    # The bottom layer is conditioned on the author context (and carries
    # the attention, when enabled); the remaining layers are plain cells
    # spread round-robin across the GPUs.
    context_cell = ContextWrapper(base_cell(self._cell_size),
                                  embedded_authors)
    #context_cell = base_cell(self._cell_size)
    bottom_cell = DeviceWrapper(attn_cell(context_cell), "/gpu:0")
    top_cells = [
        DeviceWrapper(base_cell(self._cell_size), "/gpu:" + gpu(i))
        for i in range(1, self._cell_num)
    ]
    cell = MultiRNNCell([bottom_cell] + top_cells)
    init_state = cell.zero_state(batch_size, tf.float32)

    if self._training:
        helper = TrainingHelper(embedded_sequence, sequence_lengths)
    else:
        # Sample from the model at inference time; 1 is the end token id.
        helper = SampleEmbeddingHelper(embedding, sequence[:, 0], 1)

    dense = Dense(self._vocab_size, self._activation)
    decoder = BasicDecoder(cell, helper, init_state, dense)
    output, state, _ = dynamic_decode(decoder, swap_memory=True)
    logits = output.rnn_output

    # Unlike the masked sequence_loss variant above, this loss averages
    # over all positions, padding included.
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                   labels=one_hot_targets)
    loss = tf.reduce_mean(loss)
    out = tf.nn.softmax(logits)
    return sequence, authors, targets, loss, out
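# Hypothetical training wiring for the graph built by _init() above; the
# model object, optimizer choice, learning rate, and the *_batch arrays are
# illustrative assumptions, not part of the original code.
def build_and_train(model, seq_batch, author_batch, target_batch):
    # Build the graph once, then run a single optimization step.
    sequence, authors, targets, loss, _ = model._init()
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        _, batch_loss = sess.run(
            [train_op, loss],
            feed_dict={sequence: seq_batch,      # int32 [batch, time]
                       authors: author_batch,    # int32 [batch, time]
                       targets: target_batch})   # int32 [batch, time]
    return batch_loss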