def testStepWithGreedyEmbeddingHelper(self):
    """Run one decode step with GreedyEmbeddingHelper and verify that the
    declared output sizes/dtypes hold and that the greedily sampled ids
    drive both the finished flags and the next-step inputs."""
    num_batches = 5
    vocab_size = 7
    depth = vocab_size  # the cell's logits must match the vocabulary size
    emb_dim = 10
    starts = np.random.randint(0, vocab_size, size=num_batches)
    stop_id = 1
    with self.session(use_gpu=True) as sess:
        emb_matrix = np.random.randn(vocab_size, emb_dim).astype(np.float32)
        lstm = rnn_cell.LSTMCell(vocab_size)
        greedy_helper = helper_py.GreedyEmbeddingHelper(emb_matrix, starts,
                                                        stop_id)
        dec = basic_decoder.BasicDecoder(
            cell=lstm,
            helper=greedy_helper,
            initial_state=lstm.zero_state(dtype=dtypes.float32,
                                          batch_size=num_batches))
        out_size = dec.output_size
        out_dtype = dec.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(depth,
                                             tensor_shape.TensorShape([])),
            out_size)
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
            out_dtype)
        first_finished, first_inputs, first_state = dec.initialize()
        step_outputs, step_state, step_next_inputs, step_finished = dec.step(
            constant_op.constant(0), first_inputs, first_state)
        dyn_batch_size = dec.batch_size
        # Static structure and shape checks before running the graph.
        self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
        self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
        self.assertTrue(
            isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
        self.assertEqual((num_batches, depth), step_outputs[0].get_shape())
        self.assertEqual((num_batches,), step_outputs[1].get_shape())
        self.assertEqual((num_batches, depth), first_state[0].get_shape())
        self.assertEqual((num_batches, depth), first_state[1].get_shape())
        self.assertEqual((num_batches, depth), step_state[0].get_shape())
        self.assertEqual((num_batches, depth), step_state[1].get_shape())
        sess.run(variables.global_variables_initializer())
        fetched = sess.run({
            "batch_size": dyn_batch_size,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished
        })
        # Greedy sampling is the argmax over the logits dimension; the next
        # inputs are the embedding rows of the sampled ids.
        want_ids = np.argmax(fetched["step_outputs"].rnn_output, -1)
        want_finished = (want_ids == stop_id)
        want_next_inputs = emb_matrix[want_ids]
        self.assertAllEqual([False, False, False, False, False],
                            fetched["first_finished"])
        self.assertAllEqual(want_finished, fetched["step_finished"])
        self.assertEqual(out_dtype.sample_id,
                         fetched["step_outputs"].sample_id.dtype)
        self.assertAllEqual(want_ids, fetched["step_outputs"].sample_id)
        self.assertAllEqual(want_next_inputs, fetched["step_next_inputs"])
def _build_decoder(self, encoder_outputs, encoder_state):
    """Build the decoding sub-graph.

    In TRAIN mode a TrainingHelper-driven BasicDecoder is built; otherwise
    either a greedy BasicDecoder (beam_width == 1) or a BeamSearchDecoder
    is built, with encoder outputs/state/lengths tiled for the beam.

    Args:
        encoder_outputs: encoder outputs used as the attention memory.
        encoder_state: final encoder state (tiled when beam search is used).

    Returns:
        A ``(output_ids, outputs)`` pair: sampled ids and logits for
        train/greedy decoding, or predicted ids and beam scores for
        beam search.
    """
    with tf.name_scope("seq_decoder"):
        batch_size = self.batch_size
        # sequence_length = tf.fill([self.batch_size], self.num_steps)
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            sequence_length = self.iterator.target_length
        else:
            sequence_length = self.iterator.source_length
        if (self.mode !=
                tf.contrib.learn.ModeKeys.TRAIN) and self.beam_width > 1:
            # Tile memory, state and lengths so each beam sees its own copy.
            batch_size = batch_size * self.beam_width
            encoder_outputs = beam_search_decoder.tile_batch(
                encoder_outputs, multiplier=self.beam_width)
            encoder_state = nest.map_structure(
                lambda s: beam_search_decoder.tile_batch(s, self.beam_width),
                encoder_state)
            sequence_length = beam_search_decoder.tile_batch(
                sequence_length, multiplier=self.beam_width)
        # BUG FIX: the original built ONE cell and repeated the same instance
        # in the MultiRNNCell list, which makes every layer share state/weights
        # (and is rejected outright by newer TF). Build a fresh cell per layer.
        decoder_cell = MultiRNNCell([
            single_rnn_cell(self.hparams.unit_type, self.num_units,
                            self.dropout)
            for _ in range(self.num_layers_decoder)
        ])
        decoder_cell = InputProjectionWrapper(decoder_cell,
                                              num_proj=self.num_units)
        attention_mechanism = create_attention_mechanism(
            self.hparams.attention_mechanism,
            self.num_units,
            memory=encoder_outputs,
            source_sequence_length=sequence_length)
        decoder_cell = wrapper.AttentionWrapper(
            decoder_cell,
            attention_mechanism,
            attention_layer_size=self.num_units,
            output_attention=True,
            alignment_history=False)
        # Original (Korean) comment said: "set the AttentionWrapperState's
        # cell_state to the encoder state" — but the code below starts from a
        # pure zero state and never installs encoder_state (no
        # .clone(cell_state=...)). NOTE(review): confirm whether
        # zero_state(...).clone(cell_state=encoder_state) was intended.
        initial_state = decoder_cell.zero_state(batch_size=batch_size,
                                                dtype=tf.float32)
        embeddings_decoder = tf.get_variable(
            "embedding_decoder", [self.num_decoder_symbols, self.num_units],
            initializer=self.initializer,
            dtype=tf.float32)
        output_layer = Dense(units=self.num_decoder_symbols,
                             use_bias=True,
                             name="output_layer")
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            decoder_inputs = tf.nn.embedding_lookup(embeddings_decoder,
                                                    self.iterator.target_in)
            decoder_helper = helper.TrainingHelper(
                decoder_inputs, sequence_length=sequence_length)
            dec = basic_decoder.BasicDecoder(decoder_cell,
                                             decoder_helper,
                                             initial_state,
                                             output_layer=output_layer)
            final_outputs, final_state, _ = decoder.dynamic_decode(dec)
            # (The original assigned output_ids/outputs here with the names
            # swapped; both were dead — unconditionally overwritten below —
            # so the assignments were removed.)
        else:

            def embedding_fn(inputs):
                return tf.nn.embedding_lookup(embeddings_decoder, inputs)

            # Cap decoding at 2x the longest source sequence.
            decoding_length_factor = 2.0
            max_encoder_length = tf.reduce_max(self.iterator.source_length)
            maximum_iterations = tf.to_int32(
                tf.round(
                    tf.to_float(max_encoder_length) * decoding_length_factor))
            tgt_sos_id = tf.cast(
                self.tgt_vocab_table.lookup(tf.constant(self.hparams.sos)),
                tf.int32)
            tgt_eos_id = tf.cast(
                self.tgt_vocab_table.lookup(tf.constant(self.hparams.eos)),
                tf.int32)
            # start_tokens stays untiled: BeamSearchDecoder tiles internally.
            start_tokens = tf.fill([self.batch_size], tgt_sos_id)
            end_token = tgt_eos_id
            if self.beam_width == 1:
                decoder_helper = helper.GreedyEmbeddingHelper(
                    embedding=embedding_fn,
                    start_tokens=start_tokens,
                    end_token=end_token)
                dec = basic_decoder.BasicDecoder(decoder_cell,
                                                 decoder_helper,
                                                 initial_state,
                                                 output_layer=output_layer)
            else:
                dec = beam_search_decoder.BeamSearchDecoder(
                    cell=decoder_cell,
                    embedding=embedding_fn,
                    start_tokens=start_tokens,
                    end_token=end_token,
                    initial_state=initial_state,
                    output_layer=output_layer,
                    beam_width=self.beam_width)
            final_outputs, final_state, _ = decoder.dynamic_decode(
                dec,
                # swap_memory=True,
                maximum_iterations=maximum_iterations)
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN or self.beam_width == 1:
            output_ids = final_outputs.sample_id
            outputs = final_outputs.rnn_output
        else:
            output_ids = final_outputs.predicted_ids
            outputs = final_outputs.beam_search_decoder_output.scores
        return output_ids, outputs
def build(self):
    """Construct the full graph: string-lookup tables, placeholders,
    encoder RNN, topic attention states, and a topic-aware attention
    decoder — a training branch (loss + backprop + saver) when
    ``self.for_deploy`` is False, otherwise a greedy inference branch
    (decoded strings + serving exporter)."""
    conf = self.conf
    name = self.name
    job_type = self.job_type  # NOTE(review): not used below in this method
    dtype = self.dtype
    # Input maps: string -> id tables for encoder/decoder tokens and topics,
    # and an id -> string table for emitting readable output.
    self.in_table = lookup.MutableHashTable(key_dtype=tf.string,
                                            value_dtype=tf.int64,
                                            default_value=UNK_ID,
                                            shared_name="in_table",
                                            name="in_table",
                                            checkpoint=True)
    self.topic_in_table = lookup.MutableHashTable(
        key_dtype=tf.string,
        value_dtype=tf.int64,
        default_value=2,  # presumably the topic-UNK id — confirm with vocab
        shared_name="topic_in_table",
        name="topic_in_table",
        checkpoint=True)
    self.out_table = lookup.MutableHashTable(key_dtype=tf.int64,
                                             value_dtype=tf.string,
                                             default_value="_UNK",
                                             shared_name="out_table",
                                             name="out_table",
                                             checkpoint=True)
    graphlg.info("Creating placeholders...")
    self.enc_str_inps = tf.placeholder(tf.string,
                                       shape=(None, conf.input_max_len),
                                       name="enc_inps")
    self.enc_lens = tf.placeholder(tf.int32, shape=[None], name="enc_lens")
    self.enc_str_topics = tf.placeholder(tf.string,
                                         shape=(None, None),
                                         name="enc_topics")
    # +2 leaves room for the GO/EOS markers around the target sequence
    # (presumably — confirm against the input pipeline).
    self.dec_str_inps = tf.placeholder(
        tf.string, shape=[None, conf.output_max_len + 2], name="dec_inps")
    self.dec_lens = tf.placeholder(tf.int32, shape=[None], name="dec_lens")
    # table lookup: convert the string placeholders to id tensors.
    self.enc_inps = self.in_table.lookup(self.enc_str_inps)
    self.enc_topics = self.topic_in_table.lookup(self.enc_str_topics)
    self.dec_inps = self.in_table.lookup(self.dec_str_inps)
    batch_size = tf.shape(self.enc_inps)[0]
    with variable_scope.variable_scope(self.model_kind, dtype=dtype) as scope:
        # Create encode graph and get attn states.
        graphlg.info("Creating embeddings and do lookup...")
        # Encoder runs time-major, hence the transpose.
        t_major_enc_inps = tf.transpose(self.enc_inps)
        with ops.device("/cpu:0"):  # keep large embeddings off the GPU
            self.embedding = variable_scope.get_variable(
                "embedding", [conf.input_vocab_size, conf.embedding_size])
            self.emb_enc_inps = embedding_lookup_unique(
                self.embedding, t_major_enc_inps)
            # Topic embeddings are frozen (trainable=False).
            self.topic_embedding = variable_scope.get_variable(
                "topic_embedding",
                [conf.topic_vocab_size, conf.topic_embedding_size],
                trainable=False)
            self.emb_enc_topics = embedding_lookup_unique(
                self.topic_embedding, self.enc_topics)
        graphlg.info("Creating out projection weights...")
        # Output projection (w, b) maps decoder output to vocab logits; its
        # input width depends on whether an extra out layer is configured.
        if conf.out_layer_size != None:
            w = tf.get_variable(
                "proj_w", [conf.out_layer_size, conf.output_vocab_size],
                dtype=dtype)
        else:
            w = tf.get_variable("proj_w",
                                [conf.num_units, conf.output_vocab_size],
                                dtype=dtype)
        b = tf.get_variable("proj_b", [conf.output_vocab_size], dtype=dtype)
        self.out_proj = (w, b)
        graphlg.info("Creating encoding dynamic rnn...")
        with variable_scope.variable_scope("encoder", dtype=dtype) as scope:
            if conf.bidirectional:
                cell_fw = CreateMultiRNNCell(conf.cell_model, conf.num_units,
                                             conf.num_layers,
                                             conf.output_keep_prob)
                cell_bw = CreateMultiRNNCell(conf.cell_model, conf.num_units,
                                             conf.num_layers,
                                             conf.output_keep_prob)
                self.enc_outs, self.enc_states = bidirectional_dynamic_rnn(
                    cell_fw=cell_fw,
                    cell_bw=cell_bw,
                    inputs=self.emb_enc_inps,
                    sequence_length=self.enc_lens,
                    dtype=dtype,
                    parallel_iterations=16,
                    time_major=True,
                    scope=scope)
                # Concatenate forward/backward states and outputs so the
                # decoder sees a single 2*num_units-wide representation.
                fw_s, bw_s = self.enc_states
                self.enc_states = tuple([
                    tf.concat([f, b], axis=1) for f, b in zip(fw_s, bw_s)
                ])
                self.enc_outs = tf.concat(
                    [self.enc_outs[0], self.enc_outs[1]], axis=2)
            else:
                cell = CreateMultiRNNCell(conf.cell_model, conf.num_units,
                                          conf.num_layers,
                                          conf.output_keep_prob)
                self.enc_outs, self.enc_states = dynamic_rnn(
                    cell=cell,
                    inputs=self.emb_enc_inps,
                    sequence_length=self.enc_lens,
                    parallel_iterations=16,
                    scope=scope,
                    dtype=dtype,
                    time_major=True)
        attn_len = tf.shape(self.enc_outs)[0]
        graphlg.info("Preparing init attention and states for decoder...")
        # Decoder starts from the (concatenated) encoder state; attention
        # memory is batch-major, hence the transpose back from time-major.
        initial_state = self.enc_states
        attn_states = tf.transpose(self.enc_outs, perm=[1, 0, 2])
        attn_size = self.conf.num_units
        topic_attn_size = self.conf.num_units
        # 1x1 conv projects topic embeddings into the attention space.
        k = tf.get_variable(
            "topic_proj",
            [1, 1, self.conf.topic_embedding_size, topic_attn_size])
        topic_attn_states = nn_ops.conv2d(
            tf.expand_dims(self.emb_enc_topics, 2), k, [1, 1, 1, 1], "SAME")
        topic_attn_states = tf.squeeze(topic_attn_states, axis=2)
        graphlg.info("Creating decoder cell...")
        with variable_scope.variable_scope("decoder", dtype=dtype) as scope:
            cell = CreateMultiRNNCell(conf.cell_model, attn_size,
                                      conf.num_layers, conf.output_keep_prob)
            # topic
            if not self.for_deploy:
                # ---- Training branch ----
                graphlg.info(
                    "Embedding decoder inps, tars and tar weights...")
                t_major_dec_inps = tf.transpose(self.dec_inps)
                # Targets are the inputs shifted by one step.
                t_major_tars = tf.slice(t_major_dec_inps, [1, 0],
                                        [conf.output_max_len + 1, -1])
                t_major_dec_inps = tf.slice(t_major_dec_inps, [0, 0],
                                            [conf.output_max_len + 1, -1])
                # Reverse cumsum of a one-hot at (dec_len - 1) yields 1.0 for
                # every valid timestep and 0.0 after it — per-step loss mask.
                t_major_tar_wgts = tf.cumsum(tf.one_hot(
                    self.dec_lens - 1, conf.output_max_len + 1, axis=0),
                                             axis=0,
                                             reverse=True)
                with ops.device("/cpu:0"):
                    emb_dec_inps = embedding_lookup_unique(
                        self.embedding, t_major_dec_inps)
                # NOTE(review): sequence_length here is the *encoder* lengths
                # (self.enc_lens); self.dec_lens looks more appropriate for a
                # decoder-side helper — confirm intent.
                hp_train = helper.ScheduledEmbeddingTrainingHelper(
                    inputs=emb_dec_inps,
                    sequence_length=self.enc_lens,
                    embedding=self.embedding,
                    sampling_probability=0.0,
                    out_proj=self.out_proj,
                    except_ids=None,
                    time_major=True)
                output_layer = None
                my_decoder = AttnTopicDecoder(
                    cell=cell,
                    helper=hp_train,
                    initial_state=initial_state,
                    attn_states=attn_states,
                    attn_size=attn_size,
                    topic_attn_states=topic_attn_states,
                    topic_attn_size=topic_attn_size,
                    output_layer=output_layer)
                t_major_cell_outs, final_state = decoder.dynamic_decode(
                    decoder=my_decoder,
                    output_time_major=True,
                    maximum_iterations=conf.output_max_len + 1,
                    scope=scope)
                t_major_outs = t_major_cell_outs.rnn_output
                # Branch 1 for debugging, doesn't have to be called:
                # project to vocab logits, argmax, and map ids back to strings.
                self.outputs = tf.transpose(t_major_outs, perm=[1, 0, 2])
                L = tf.shape(self.outputs)[1]
                w, b = self.out_proj
                self.outputs = tf.reshape(self.outputs, [-1, int(w.shape[0])])
                self.outputs = tf.matmul(self.outputs, w) + b
                # For masking the except_ids when debuging
                #m = tf.shape(self.outputs)[0]
                #self.mask = tf.zeros([m, int(w.shape[1])])
                #for i in [3]:
                #    self.mask = self.mask + tf.one_hot(indices=tf.ones([m], dtype=tf.int32) * i, on_value=100.0, depth=int(w.shape[1]))
                #self.outputs = self.outputs - self.mask
                self.outputs = tf.argmax(self.outputs, axis=1)
                self.outputs = tf.reshape(self.outputs, [-1, L])
                self.outputs = self.out_table.lookup(
                    tf.cast(self.outputs, tf.int64))
                # Branch 2 for loss
                self.loss = dyn_sequence_loss(self.conf, t_major_outs,
                                              self.out_proj, t_major_tars,
                                              t_major_tar_wgts)
                self.summary = tf.summary.scalar("%s/loss" % self.name,
                                                 self.loss)
                # backpropagation
                self.build_backprop(self.loss, conf, dtype)
                #saver — topic_embedding is trainable=False, so it is added to
                # the saved/trainable lists explicitly.
                self.trainable_params.extend(tf.trainable_variables() +
                                             [self.topic_embedding])
                need_to_save = self.global_params + self.trainable_params + self.optimizer_params + tf.get_default_graph(
                ).get_collection("saveable_objects") + [
                    self.topic_embedding
                ]
                self.saver = tf.train.Saver(need_to_save,
                                            max_to_keep=conf.max_to_keep)
            else:
                # ---- Inference / serving branch ----
                # start_tokens of all-ones: presumably id 1 is GO — confirm
                # against the vocabulary.
                hp_infer = helper.GreedyEmbeddingHelper(
                    embedding=self.embedding,
                    start_tokens=tf.ones(shape=[batch_size], dtype=tf.int32),
                    end_token=EOS_ID,
                    out_proj=self.out_proj)
                output_layer = None  #layers_core.Dense(self.conf.outproj_from_size, use_bias=True)
                my_decoder = AttnTopicDecoder(
                    cell=cell,
                    helper=hp_infer,
                    initial_state=initial_state,
                    attn_states=attn_states,
                    attn_size=attn_size,
                    topic_attn_states=topic_attn_states,
                    topic_attn_size=topic_attn_size,
                    output_layer=output_layer)
                cell_outs, final_state = decoder.dynamic_decode(
                    decoder=my_decoder, scope=scope, maximum_iterations=40)
                self.outputs = cell_outs.sample_id
                #lookup: map sampled ids back to output strings
                self.outputs = self.out_table.lookup(
                    tf.cast(self.outputs, tf.int64))
                #saver
                self.trainable_params.extend(tf.trainable_variables())
                self.saver = tf.train.Saver(max_to_keep=conf.max_to_keep)
                # Exporter for serving
                self.model_exporter = exporter.Exporter(self.saver)
                inputs = {
                    "enc_inps": self.enc_str_inps,
                    "enc_lens": self.enc_lens
                }
                outputs = {"out": self.outputs}
                self.model_exporter.init(
                    tf.get_default_graph().as_graph_def(),
                    named_graph_signatures={
                        "inputs": exporter.generic_signature(inputs),
                        "outputs": exporter.generic_signature(outputs)
                    })
                graphlg.info("Graph done")
                graphlg.info("")
    # final_state comes from whichever branch was built above.
    self.dec_states = final_state
tgt_inputs = tf.random.normal((batch_size, tgt_max_times, num_units), dtype=tf.float32) training_helper = helper_py.TrainingHelper(tgt_inputs, tgt_len) # train helper train_decoder = basic_decoder.BasicDecoder( cell=attnRNNCell, helper=training_helper, initial_state=attnRNNCell.zero_state(batch_size, tf.float32)) # inference embedding = tf.get_variable("embedding", shape=(10, 16), initializer=tf.random_uniform_initializer()) infer_helper = helper_py.GreedyEmbeddingHelper( embedding=embedding, # 可以是callable,也可以是embedding矩阵 start_tokens=tf.zeros([batch_size], dtype=tf.int32), end_token=9) infer_decoder = basic_decoder.BasicDecoder( cell=attnRNNCell, helper=infer_helper, initial_state=attnRNNCell.zero_state(batch_size, tf.float32)) final_outputs, final_state, final_sequence_lengths = decoder.dynamic_decode( train_decoder, maximum_iterations=False) print(final_outputs.rnn_output) print(final_outputs.sample_id) print(final_state.cell_state) print(final_sequence_lengths) print("--------infer-------------") final_outputs, final_state, final_sequence_lengths = decoder.dynamic_decode(