def create_pretrain_network(self):
    """Build the supervised (MLE) pretraining graph for the generator."""
    with tf.variable_scope('attention'):
        (self.attention_keys, self.attention_values, _,
         self.attention_construct_fn) = attention_utils.prepare_attention(
             self.encoded_seed, 'luong', num_units=self.hidden_dim, reuse=True)

    # Supervised pretraining for the generator: collect the per-step softmax
    # distributions over the vocabulary.
    g_predictions = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                 size=self.sequence_length,
                                                 dynamic_size=False,
                                                 infer_shape=True)

    # Embed the ground-truth inputs on the CPU and switch to time-major layout.
    with tf.device("/cpu:0"):
        self.processed_x = tf.transpose(
            tf.nn.embedding_lookup(self.g_embeddings, self.x),
            perm=[1, 0, 2])  # seq_length x batch_size x emb_dim

    ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32,
                                            size=self.sequence_length)
    ta_emb_x = ta_emb_x.unstack(self.processed_x)

    output_mask = None
    if self.is_training:
        output_mask = make_mask(self.batch_size, self.gen_vd_keep_prob,
                                self.hidden_dim)

    def _pretrain_recurrence(i, x_t, prev_h_state, g_predictions):
        out, state = self.g_recurrent_unit(x_t, prev_h_state)
        out = self.attention_construct_fn(out, self.attention_keys,
                                          self.attention_values)
        if output_mask is not None:
            out *= output_mask
        o_t = self.g_output_unit(out)  # batch x vocab, logits not prob
        g_predictions = g_predictions.write(
            i, tf.nn.softmax(o_t))  # batch x vocab_size
        # Teacher forcing: the next input is the ground-truth embedding.
        x_tp1 = ta_emb_x.read(i)
        return i + 1, x_tp1, state, g_predictions

    _, _, _, self.g_predictions = control_flow_ops.while_loop(
        cond=lambda i, _1, _2, _3: i < self.sequence_length,
        body=_pretrain_recurrence,
        loop_vars=(tf.constant(0, dtype=tf.int32),
                   tf.nn.embedding_lookup(self.g_embeddings, self.start_token),
                   self.h0, g_predictions))

    self.g_predictions = self.g_predictions.stack()
    self.g_predictions = tf.transpose(
        self.g_predictions,
        perm=[1, 0, 2])  # batch_size x seq_length x vocab_size
    print(self.g_predictions.shape)
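    # The per-step softmax distributions gathered in self.g_predictions feed a
    # maximum-likelihood pretraining objective. A minimal sketch of such a loss,
    # assuming self.x holds the target token ids and a self.vocab_size attribute
    # exists; the name pretrain_loss is illustrative, not from the original code.
    self.pretrain_loss = -tf.reduce_sum(
        tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])),
                   self.vocab_size, 1.0, 0.0) *
        tf.log(tf.clip_by_value(
            tf.reshape(self.g_predictions, [-1, self.vocab_size]),
            1e-20, 1.0))
    ) / (self.sequence_length * self.batch_size)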
def create(self, reuse=None):
    """Build the shared embeddings, the seed encoder, and the generator cells."""
    g_embeddings = self.create_embedding()
    # Append one extra embedding row for the missing-token id.
    init_missing_embedding = tf.random_uniform([1, self.emb_dim], -1.0, 1.0)
    missing_embedding = tf.get_variable("missing_embedding",
                                        initializer=init_missing_embedding)
    self.g_embeddings = tf.concat([g_embeddings, missing_embedding], axis=0)

    attn_cell = self.create_cell(reuse=reuse)

    with tf.variable_scope('seed_encoder') as scope:
        embedded_seed = tf.nn.embedding_lookup(self.g_embeddings,
                                               self.masked_inputs)
        cell = tf.contrib.rnn.MultiRNNCell(
            [attn_cell() for _ in range(self.n_rnn_layers)],
            state_is_tuple=True)
        state = cell.zero_state(self.batch_size, tf.float32)
        outputs, state = tf.nn.dynamic_rnn(cell, embedded_seed,
                                           initial_state=state, scope=scope)

    def _make_mask(batch_size, keep_prob, units):
        # Variational dropout mask: one keep/drop decision per batch element
        # and unit, shared across all timesteps, rescaled by 1/keep_prob.
        random_tensor = keep_prob
        random_tensor += tf.random_uniform(tf.stack([batch_size, 1, units]))
        return tf.floor(random_tensor) / keep_prob

    if self.is_training:
        output_mask = _make_mask(self.batch_size, self.gen_vd_keep_prob,
                                 self.hidden_dim)
        outputs *= output_mask

    self.encoded_seed = outputs

    # Initial generator state is the final state of the seed encoder.
    # self.h0 = self.g_recurrent_unit.zero_state(self.batch_size, tf.float32)
    self.h0 = state

    with tf.variable_scope('attention'):
        print("encoded seed:", self.encoded_seed.shape)
        (self.attention_keys, self.attention_values, _,
         self.attention_construct_fn) = attention_utils.prepare_attention(
             self.encoded_seed, 'luong', num_units=self.hidden_dim)

    with tf.variable_scope('generator'):
        self.g_recurrent_unit = tf.contrib.rnn.MultiRNNCell(
            [attn_cell() for _ in range(self.n_rnn_layers)],
            state_is_tuple=True)
        self.g_output_unit = self.create_output_unit()  # maps h_t to o_t (output token logits)
def dis_decoder(generator, embedding, sequence, encoding_state, reuse=None,
                dis_num_layers=2):
    """Discriminator decoder: scores each timestep of `sequence`, attending
    over the encoder outputs carried in `encoding_state`."""
    sequence = tf.cast(sequence, tf.int32)

    with tf.variable_scope('decoder', reuse=reuse):

        def lstm_cell():
            return tf.contrib.rnn.BasicLSTMCell(generator.hidden_dim,
                                                forget_bias=0.0,
                                                state_is_tuple=True,
                                                reuse=reuse)

        attn_cell = lstm_cell
        if generator.is_training and generator.dis_vd_keep_prob < 1:

            def attn_cell():
                return VariationalDropoutWrapper(lstm_cell(),
                                                 generator.batch_size,
                                                 generator.dis_vd_keep_prob,
                                                 generator.dis_vd_keep_prob)

        cell_dis = tf.contrib.rnn.MultiRNNCell(
            [attn_cell() for _ in range(dis_num_layers)],
            state_is_tuple=True)
        # encoding_state is (encoder outputs, encoder final state).
        state = encoding_state[1]

        (attention_keys, attention_values, _,
         attention_construct_fn) = attention_utils.prepare_attention(
             encoding_state[0], 'luong', num_units=generator.hidden_dim,
             reuse=reuse)

        if generator.is_training:
            output_mask = make_mask(generator.batch_size,
                                    generator.dis_vd_keep_prob,
                                    generator.hidden_dim)

        with tf.variable_scope('rnn') as vs:
            predictions = []
            rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
            for t in range(generator.sequence_length):
                if t > 0:
                    tf.get_variable_scope().reuse_variables()
                rnn_in = rnn_inputs[:, t]
                rnn_out, state = cell_dis(rnn_in, state)
                rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                                 attention_values)
                if generator.is_training:
                    rnn_out *= output_mask
                # One scalar prediction per timestep.
                pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
                predictions.append(pred)

    predictions = tf.stack(predictions, axis=1)
    return tf.squeeze(predictions, axis=2)  # batch_size x seq_length
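# Note: `make_mask`, used by create_pretrain_network() and dis_decoder() above,
# is not defined in this section. A minimal sketch is given below, assuming it
# mirrors the `_make_mask` helper in create() but returns a 2-D mask, since it
# multiplies per-timestep outputs of shape [batch_size, units].

def make_mask(batch_size, keep_prob, units):
    # Variational dropout mask: sample one keep/drop decision per batch element
    # and unit, reuse it at every timestep, and rescale by 1/keep_prob so the
    # activations keep their expected magnitude.
    random_tensor = keep_prob
    random_tensor += tf.random_uniform(tf.stack([batch_size, units]))
    return tf.floor(random_tensor) / keep_prob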