def testStepWithScheduledEmbeddingTrainingHelper(self):
    """One decoder step with a ScheduledEmbeddingTrainingHelper (p = 0.5).

    Builds a BasicDecoder over an LSTM cell whose helper replaces the
    ground-truth next input with an embedding of a sampled id with
    probability 0.5, runs ``initialize`` plus a single ``step``, and checks
    static shapes/dtypes, the finished flags and the next inputs for both
    sampled and non-sampled rows.
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    vocabulary_size = 10

    with self.session(use_gpu=True) as sess:
        inputs = np.random.randn(
            batch_size, max_time, input_depth).astype(np.float32)
        embeddings = np.random.randn(
            vocabulary_size, input_depth).astype(np.float32)
        half = constant_op.constant(0.5)
        cell = rnn_cell.LSTMCell(vocabulary_size)
        helper = helper_py.ScheduledEmbeddingTrainingHelper(
            inputs=inputs,
            sequence_length=sequence_length,
            embedding=embeddings,
            sampling_probability=half,
            time_major=False)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))
        output_size = my_decoder.output_size
        output_dtype = my_decoder.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(vocabulary_size,
                                             tensor_shape.TensorShape([])),
            output_size)
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
            output_dtype)

        (first_finished, first_inputs, first_state) = my_decoder.initialize()
        (step_outputs, step_state, step_next_inputs,
         step_finished) = my_decoder.step(
             constant_op.constant(0), first_inputs, first_state)
        batch_size_t = my_decoder.batch_size

        # Static structure and shapes, checked before anything is run.
        self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
        self.assertEqual((batch_size, vocabulary_size),
                         step_outputs[0].get_shape())
        self.assertEqual((batch_size,), step_outputs[1].get_shape())
        self.assertEqual((batch_size, vocabulary_size),
                         first_state[0].get_shape())
        self.assertEqual((batch_size, vocabulary_size),
                         first_state[1].get_shape())
        self.assertEqual((batch_size, vocabulary_size),
                         step_state[0].get_shape())
        self.assertEqual((batch_size, vocabulary_size),
                         step_state[1].get_shape())
        self.assertEqual((batch_size, input_depth),
                         step_next_inputs.get_shape())

        sess.run(variables.global_variables_initializer())
        sess_results = sess.run({
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished
        })

        # A row is finished once the step index reaches its sequence length.
        self.assertAllEqual([False, False, False, False, True],
                            sess_results["first_finished"])
        self.assertAllEqual([False, False, False, True, True],
                            sess_results["step_finished"])
        sample_ids = sess_results["step_outputs"].sample_id
        self.assertEqual(output_dtype.sample_id, sample_ids.dtype)

        # sample_id == -1 marks rows that were not sampled this step.
        # Unpack np.where's 1-tuple so each index is a plain 1-D row array.
        (rows_not_sampling,) = np.where(sample_ids == -1)
        (rows_sampling,) = np.where(sample_ids > -1)

        # Sampled rows feed the embedding of the sampled id.
        self.assertAllClose(
            sess_results["step_next_inputs"][rows_sampling],
            embeddings[sample_ids[rows_sampling]])
        # Non-sampled rows feed the ground-truth input at time step 1.
        # Indexing with the 1-D row array always gives (k, input_depth).
        # The previous np.squeeze over a tuple-indexed (1, k, input_depth)
        # array collapsed to (input_depth,) when k == 1, which made the test
        # fail intermittently on a shape mismatch.
        self.assertAllClose(
            sess_results["step_next_inputs"][rows_not_sampling],
            inputs[rows_not_sampling, 1])
def build(self):
    """Build the topic-aware attention seq2seq graph.

    Creates string<->id lookup tables, input placeholders, the word and
    (frozen) topic embeddings, the output projection, the encoder (optionally
    bidirectional), a topic attention memory and an ``AttnTopicDecoder``.

    When ``self.for_deploy`` is false:
      - the decoder is teacher-forced by a ``ScheduledEmbeddingTrainingHelper``
        with sampling probability 0.0;
      - the loss, loss summary and backprop ops are built;
      - a saver over an explicit variable list is created.

    Otherwise:
      - the decoder is driven by a ``GreedyEmbeddingHelper``;
      - a default saver and a serving exporter are created.

    Attributes set include:
      - placeholders ``enc_str_inps``, ``enc_lens``, ``enc_str_topics``,
        ``dec_str_inps`` and ``dec_lens``;
      - ``outputs`` (decoded token strings), ``saver`` and ``dec_states``;
      - training only: ``loss`` and ``summary``;
      - deployment only: ``model_exporter``.
    """
    conf = self.conf
    dtype = self.dtype

    # Input maps: string tokens <-> integer ids, stored in checkpoints.
    self.in_table = lookup.MutableHashTable(
        key_dtype=tf.string,
        value_dtype=tf.int64,
        default_value=UNK_ID,
        shared_name="in_table",
        name="in_table",
        checkpoint=True)
    # Topic words missing from the table are mapped to id 2.
    self.topic_in_table = lookup.MutableHashTable(
        key_dtype=tf.string,
        value_dtype=tf.int64,
        default_value=2,
        shared_name="topic_in_table",
        name="topic_in_table",
        checkpoint=True)
    self.out_table = lookup.MutableHashTable(
        key_dtype=tf.int64,
        value_dtype=tf.string,
        default_value="_UNK",
        shared_name="out_table",
        name="out_table",
        checkpoint=True)

    graphlg.info("Creating placeholders...")
    self.enc_str_inps = tf.placeholder(
        tf.string, shape=(None, conf.input_max_len), name="enc_inps")
    self.enc_lens = tf.placeholder(tf.int32, shape=[None], name="enc_lens")
    self.enc_str_topics = tf.placeholder(
        tf.string, shape=(None, None), name="enc_topics")
    # Decoder rows hold output_max_len + 2 tokens.
    # Decoder inputs are the first output_max_len + 1 tokens; targets are the
    # same window shifted by one step (see the training branch).
    self.dec_str_inps = tf.placeholder(
        tf.string, shape=[None, conf.output_max_len + 2], name="dec_inps")
    self.dec_lens = tf.placeholder(tf.int32, shape=[None], name="dec_lens")

    # Table lookup: strings -> int64 ids.
    self.enc_inps = self.in_table.lookup(self.enc_str_inps)
    self.enc_topics = self.topic_in_table.lookup(self.enc_str_topics)
    self.dec_inps = self.in_table.lookup(self.dec_str_inps)

    batch_size = tf.shape(self.enc_inps)[0]

    with variable_scope.variable_scope(self.model_kind, dtype=dtype):
        # Embeddings (pinned to the CPU) and time-major encoder inputs.
        graphlg.info("Creating embeddings and do lookup...")
        t_major_enc_inps = tf.transpose(self.enc_inps)
        with ops.device("/cpu:0"):
            self.embedding = variable_scope.get_variable(
                "embedding", [conf.input_vocab_size, conf.embedding_size])
            self.emb_enc_inps = embedding_lookup_unique(
                self.embedding, t_major_enc_inps)
            # Topic embeddings are frozen (trainable=False).
            self.topic_embedding = variable_scope.get_variable(
                "topic_embedding",
                [conf.topic_vocab_size, conf.topic_embedding_size],
                trainable=False)
            self.emb_enc_topics = embedding_lookup_unique(
                self.topic_embedding, self.enc_topics)

        graphlg.info("Creating out projection weights...")
        proj_in_size = (conf.out_layer_size
                        if conf.out_layer_size is not None else conf.num_units)
        w = tf.get_variable(
            "proj_w", [proj_in_size, conf.output_vocab_size], dtype=dtype)
        b = tf.get_variable("proj_b", [conf.output_vocab_size], dtype=dtype)
        self.out_proj = (w, b)

        graphlg.info("Creating encoding dynamic rnn...")
        with variable_scope.variable_scope("encoder", dtype=dtype) as scope:
            if conf.bidirectional:
                cell_fw = CreateMultiRNNCell(conf.cell_model, conf.num_units,
                                             conf.num_layers,
                                             conf.output_keep_prob)
                cell_bw = CreateMultiRNNCell(conf.cell_model, conf.num_units,
                                             conf.num_layers,
                                             conf.output_keep_prob)
                self.enc_outs, self.enc_states = bidirectional_dynamic_rnn(
                    cell_fw=cell_fw,
                    cell_bw=cell_bw,
                    inputs=self.emb_enc_inps,
                    sequence_length=self.enc_lens,
                    dtype=dtype,
                    parallel_iterations=16,
                    time_major=True,
                    scope=scope)
                # Merge both directions along the feature axis, per layer for
                # the states and per step for the outputs.
                # NOTE(review): this doubles the feature size to
                # 2 * num_units, while the decoder cell below is built with
                # attn_size = num_units. Also, tf.concat on tuple-valued
                # per-layer states (e.g. LSTM) would not concatenate c/h
                # separately. Confirm bidirectional mode is actually used.
                fw_states, bw_states = self.enc_states
                self.enc_states = tuple(
                    tf.concat([fw, bw], axis=1)
                    for fw, bw in zip(fw_states, bw_states))
                self.enc_outs = tf.concat(
                    [self.enc_outs[0], self.enc_outs[1]], axis=2)
            else:
                cell = CreateMultiRNNCell(conf.cell_model, conf.num_units,
                                          conf.num_layers,
                                          conf.output_keep_prob)
                self.enc_outs, self.enc_states = dynamic_rnn(
                    cell=cell,
                    inputs=self.emb_enc_inps,
                    sequence_length=self.enc_lens,
                    parallel_iterations=16,
                    scope=scope,
                    dtype=dtype,
                    time_major=True)

        graphlg.info("Preparing init attention and states for decoder...")
        initial_state = self.enc_states
        # The encoder ran time-major; the attention memory is batch-major.
        attn_states = tf.transpose(self.enc_outs, perm=[1, 0, 2])
        attn_size = conf.num_units
        topic_attn_size = conf.num_units
        # Project each topic embedding to topic_attn_size with a 1x1
        # convolution: rank-3 topic embeddings gain a unit axis at 2 for
        # conv2d, which is then squeezed away again.
        topic_proj = tf.get_variable(
            "topic_proj", [1, 1, conf.topic_embedding_size, topic_attn_size])
        topic_attn_states = nn_ops.conv2d(
            tf.expand_dims(self.emb_enc_topics, 2), topic_proj, [1, 1, 1, 1],
            "SAME")
        topic_attn_states = tf.squeeze(topic_attn_states, axis=2)

        graphlg.info("Creating decoder cell...")
        with variable_scope.variable_scope("decoder", dtype=dtype) as scope:
            cell = CreateMultiRNNCell(conf.cell_model, attn_size,
                                      conf.num_layers, conf.output_keep_prob)
            if not self.for_deploy:
                graphlg.info("Embedding decoder inps, tars and tar weights...")
                t_major_dec_inps = tf.transpose(self.dec_inps)
                # Targets are the decoder tokens shifted left by one step.
                t_major_tars = tf.slice(t_major_dec_inps, [1, 0],
                                        [conf.output_max_len + 1, -1])
                t_major_dec_inps = tf.slice(t_major_dec_inps, [0, 0],
                                            [conf.output_max_len + 1, -1])
                # Target weights: 1.0 for steps t < dec_lens, 0.0 afterwards.
                # Built as a one-hot at dec_lens - 1 followed by a reverse
                # cumulative sum over time.
                t_major_tar_wgts = tf.cumsum(
                    tf.one_hot(self.dec_lens - 1, conf.output_max_len + 1,
                               axis=0),
                    axis=0,
                    reverse=True)
                with ops.device("/cpu:0"):
                    emb_dec_inps = embedding_lookup_unique(
                        self.embedding, t_major_dec_inps)
                # Teacher forcing (sampling probability 0.0). Decoding must
                # follow the decoder lengths, the same span the target weights
                # cover. Encoder lengths would stop too early or read past the
                # output_max_len + 1 decoder inputs.
                hp_train = helper.ScheduledEmbeddingTrainingHelper(
                    inputs=emb_dec_inps,
                    sequence_length=self.dec_lens,
                    embedding=self.embedding,
                    sampling_probability=0.0,
                    out_proj=self.out_proj,
                    except_ids=None,
                    time_major=True)
                my_decoder = AttnTopicDecoder(
                    cell=cell,
                    helper=hp_train,
                    initial_state=initial_state,
                    attn_states=attn_states,
                    attn_size=attn_size,
                    topic_attn_states=topic_attn_states,
                    topic_attn_size=topic_attn_size,
                    output_layer=None)
                t_major_cell_outs, final_state = decoder.dynamic_decode(
                    decoder=my_decoder,
                    output_time_major=True,
                    maximum_iterations=conf.output_max_len + 1,
                    scope=scope)
                t_major_outs = t_major_cell_outs.rnn_output

                # Branch 1, for debugging only (training does not need it).
                # Greedy argmax over the projected teacher-forced outputs,
                # mapped back to token strings.
                self.outputs = tf.transpose(t_major_outs, perm=[1, 0, 2])
                out_len = tf.shape(self.outputs)[1]
                w, b = self.out_proj
                self.outputs = tf.reshape(self.outputs, [-1, int(w.shape[0])])
                self.outputs = tf.matmul(self.outputs, w) + b
                self.outputs = tf.argmax(self.outputs, axis=1)
                self.outputs = tf.reshape(self.outputs, [-1, out_len])
                self.outputs = self.out_table.lookup(
                    tf.cast(self.outputs, tf.int64))

                # Branch 2: the training loss.
                self.loss = dyn_sequence_loss(self.conf, t_major_outs,
                                              self.out_proj, t_major_tars,
                                              t_major_tar_wgts)
                self.summary = tf.summary.scalar("%s/loss" % self.name,
                                                 self.loss)

                # Backpropagation.
                self.build_backprop(self.loss, conf, dtype)

                # Saver. topic_embedding is not trainable, so it is added to
                # trainable_params explicitly. It therefore must not be
                # appended to the var_list a second time.
                self.trainable_params.extend(tf.trainable_variables() +
                                             [self.topic_embedding])
                need_to_save = (
                    self.global_params + self.trainable_params +
                    self.optimizer_params +
                    tf.get_default_graph().get_collection("saveable_objects"))
                self.saver = tf.train.Saver(need_to_save,
                                            max_to_keep=conf.max_to_keep)
            else:
                # Greedy decoding. Every sequence starts from token id 1.
                # NOTE(review): id 1 is presumably the GO token, and the
                # 40-step cap is hard-coded. Confirm both against the
                # vocabulary and config.
                hp_infer = helper.GreedyEmbeddingHelper(
                    embedding=self.embedding,
                    start_tokens=tf.ones(shape=[batch_size], dtype=tf.int32),
                    end_token=EOS_ID,
                    out_proj=self.out_proj)
                my_decoder = AttnTopicDecoder(
                    cell=cell,
                    helper=hp_infer,
                    initial_state=initial_state,
                    attn_states=attn_states,
                    attn_size=attn_size,
                    topic_attn_states=topic_attn_states,
                    topic_attn_size=topic_attn_size,
                    output_layer=None)
                cell_outs, final_state = decoder.dynamic_decode(
                    decoder=my_decoder, scope=scope, maximum_iterations=40)
                # Map the sampled ids back to token strings.
                self.outputs = self.out_table.lookup(
                    tf.cast(cell_outs.sample_id, tf.int64))

                # Saver with the default var_list.
                self.trainable_params.extend(tf.trainable_variables())
                self.saver = tf.train.Saver(max_to_keep=conf.max_to_keep)

                # Exporter for serving. The decoder attends over topic states
                # derived from enc_topics, so that placeholder must be fed at
                # serving time and belongs in the input signature.
                self.model_exporter = exporter.Exporter(self.saver)
                inputs = {
                    "enc_inps": self.enc_str_inps,
                    "enc_lens": self.enc_lens,
                    "enc_topics": self.enc_str_topics,
                }
                outputs = {"out": self.outputs}
                self.model_exporter.init(
                    tf.get_default_graph().as_graph_def(),
                    named_graph_signatures={
                        "inputs": exporter.generic_signature(inputs),
                        "outputs": exporter.generic_signature(outputs)
                    })

            graphlg.info("Graph done")
            graphlg.info("")
            self.dec_states = final_state