Exemple #1
0
  def testStepWithGreedyEmbeddingHelper(self):
    """Checks a single `BasicDecoder.step` driven by `GreedyEmbeddingHelper`.

    The test verifies, for one decoding step:
      * the decoder's static `output_size` / `output_dtype`;
      * the static shapes and types of the step outputs and LSTM states;
      * that no sequence is finished right after `initialize()`;
      * that `sample_id` equals the argmax of `rnn_output`, that a sequence
        is marked finished exactly when its sample id equals `end_token`,
        and that the next inputs are the embedding rows of the sample ids.
    """
    batch_size = 5
    vocabulary_size = 7
    cell_depth = vocabulary_size  # cell's logits must match vocabulary size
    input_depth = 10
    start_tokens = np.random.randint(0, vocabulary_size, size=batch_size)
    end_token = 1

    with self.session(use_gpu=True) as sess:
      embeddings = np.random.randn(vocabulary_size,
                                   input_depth).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      helper = helper_py.GreedyEmbeddingHelper(embeddings, start_tokens,
                                               end_token)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtypes.float32, batch_size=batch_size))

      # Static metadata advertised by the decoder.
      output_size = my_decoder.output_size
      output_dtype = my_decoder.output_dtype
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(cell_depth,
                                           tensor_shape.TensorShape([])),
          output_size)
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
          output_dtype)

      # Build the graph for initialization followed by the step at time 0.
      (first_finished, first_inputs, first_state) = my_decoder.initialize()
      (step_outputs, step_state, step_next_inputs,
       step_finished) = my_decoder.step(
           constant_op.constant(0), first_inputs, first_state)
      batch_size_t = my_decoder.batch_size

      # Static type and shape checks (no session run needed).
      self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
      self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
      self.assertEqual((batch_size,), step_outputs[1].get_shape())
      self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
      self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
      self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
      self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      })

      # Recompute the expected greedy decisions from the evaluated logits.
      expected_sample_ids = np.argmax(
          sess_results["step_outputs"].rnn_output, -1)
      expected_step_finished = (expected_sample_ids == end_token)
      expected_step_next_inputs = embeddings[expected_sample_ids]
      # Derived from batch_size so the expectation cannot drift from it.
      self.assertAllEqual([False] * batch_size,
                          sess_results["first_finished"])
      self.assertAllEqual(expected_step_finished, sess_results["step_finished"])
      self.assertEqual(output_dtype.sample_id,
                       sess_results["step_outputs"].sample_id.dtype)
      self.assertAllEqual(expected_sample_ids,
                          sess_results["step_outputs"].sample_id)
      self.assertAllEqual(expected_step_next_inputs,
                          sess_results["step_next_inputs"])
Exemple #2
0
    def _build_decoder(self, encoder_outputs, encoder_state):
        """Build the attention decoder sub-graph.

        In TRAIN mode the decoder is teacher-forced through a
        `TrainingHelper`. In any other mode it decodes greedily when
        `self.beam_width == 1` and with a `BeamSearchDecoder` otherwise.

        Args:
            encoder_outputs: Encoder outputs, used as the attention memory.
            encoder_state: Final encoder state. It is tiled for beam search
                but is not otherwise consumed by this method.

        Returns:
            A pair `(output_ids, outputs)` whose content depends on the mode:
              * TRAIN: `(rnn_output, sample_id)`;
              * greedy inference: `(sample_id, rnn_output)`;
              * beam search: `(predicted_ids, scores)`.
        """
        is_training = self.mode == tf.contrib.learn.ModeKeys.TRAIN
        use_beam_search = (not is_training) and self.beam_width > 1

        with tf.name_scope("seq_decoder"):
            batch_size = self.batch_size
            # The attention memory is the encoder output, so it has to be
            # masked with the *source* lengths in every mode. The target
            # lengths are only relevant to the TrainingHelper below.
            source_sequence_length = self.iterator.source_length
            if use_beam_search:
                # Beam search runs `beam_width` hypotheses per example, so
                # everything indexed by batch is repeated accordingly.
                batch_size = batch_size * self.beam_width
                encoder_outputs = beam_search_decoder.tile_batch(
                    encoder_outputs, multiplier=self.beam_width)
                encoder_state = nest.map_structure(
                    lambda s: beam_search_decoder.tile_batch(
                        s, self.beam_width), encoder_state)
                source_sequence_length = beam_search_decoder.tile_batch(
                    source_sequence_length, multiplier=self.beam_width)

            single_cell = single_rnn_cell(self.hparams.unit_type,
                                          self.num_units, self.dropout)
            # NOTE(review): the same cell instance is used for every layer.
            # Depending on the TensorFlow version this shares weights
            # between layers or is rejected; left unchanged because a fix
            # would alter the set of model variables -- TODO confirm intent.
            decoder_cell = MultiRNNCell(
                [single_cell for _ in range(self.num_layers_decoder)])
            decoder_cell = InputProjectionWrapper(decoder_cell,
                                                  num_proj=self.num_units)
            attention_mechanism = create_attention_mechanism(
                self.hparams.attention_mechanism,
                self.num_units,
                memory=encoder_outputs,
                source_sequence_length=source_sequence_length)
            decoder_cell = wrapper.AttentionWrapper(
                decoder_cell,
                attention_mechanism,
                attention_layer_size=self.num_units,
                output_attention=True,
                alignment_history=False)

            # NOTE(review): the original comment here said that the
            # AttentionWrapperState's cell_state is set to the encoder
            # state, but the decoder actually starts from an all-zero state
            # and `encoder_state` is never fed in -- TODO confirm intent.
            initial_state = decoder_cell.zero_state(batch_size=batch_size,
                                                    dtype=tf.float32)
            embeddings_decoder = tf.get_variable(
                "embedding_decoder",
                [self.num_decoder_symbols, self.num_units],
                initializer=self.initializer,
                dtype=tf.float32)
            output_layer = Dense(units=self.num_decoder_symbols,
                                 use_bias=True,
                                 name="output_layer")

            if is_training:
                decoder_inputs = tf.nn.embedding_lookup(
                    embeddings_decoder, self.iterator.target_in)
                decoder_helper = helper.TrainingHelper(
                    decoder_inputs,
                    sequence_length=self.iterator.target_length)

                dec = basic_decoder.BasicDecoder(decoder_cell,
                                                 decoder_helper,
                                                 initial_state,
                                                 output_layer=output_layer)
                final_outputs, _, _ = decoder.dynamic_decode(dec)
                # NOTE(review): the names are swapped relative to the
                # inference branch (here `output_ids` holds rnn_output).
                # The returned order is kept as-is because callers may
                # depend on it.
                output_ids = final_outputs.rnn_output
                outputs = final_outputs.sample_id
            else:

                def embedding_fn(inputs):
                    return tf.nn.embedding_lookup(embeddings_decoder, inputs)

                # Allow at most twice the longest source sentence.
                decoding_length_factor = 2.0
                max_encoder_length = tf.reduce_max(self.iterator.source_length)
                maximum_iterations = tf.to_int32(
                    tf.round(
                        tf.to_float(max_encoder_length) *
                        decoding_length_factor))

                tgt_sos_id = tf.cast(
                    self.tgt_vocab_table.lookup(tf.constant(self.hparams.sos)),
                    tf.int32)
                tgt_eos_id = tf.cast(
                    self.tgt_vocab_table.lookup(tf.constant(self.hparams.eos)),
                    tf.int32)
                # start_tokens uses the un-tiled batch size in both the
                # greedy and the beam-search case.
                start_tokens = tf.fill([self.batch_size], tgt_sos_id)
                end_token = tgt_eos_id

                greedy = self.beam_width == 1
                if greedy:
                    decoder_helper = helper.GreedyEmbeddingHelper(
                        embedding=embedding_fn,
                        start_tokens=start_tokens,
                        end_token=end_token)
                    dec = basic_decoder.BasicDecoder(decoder_cell,
                                                     decoder_helper,
                                                     initial_state,
                                                     output_layer=output_layer)
                else:
                    dec = beam_search_decoder.BeamSearchDecoder(
                        cell=decoder_cell,
                        embedding=embedding_fn,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=initial_state,
                        output_layer=output_layer,
                        beam_width=self.beam_width)
                final_outputs, _, _ = decoder.dynamic_decode(
                    dec,
                    # swap_memory=True,
                    maximum_iterations=maximum_iterations)
                # This branch is only reached outside TRAIN mode, so the
                # output type is decided by the decoder kind alone.
                if greedy:
                    output_ids = final_outputs.sample_id
                    outputs = final_outputs.rnn_output
                else:
                    output_ids = final_outputs.predicted_ids
                    outputs = final_outputs.beam_search_decoder_output.scores

            return output_ids, outputs
    def build(self):
        """Build the whole model graph: lookup tables, encoder and decoder.

        The method creates, in order:
          * three `MutableHashTable`s (`in_table`, `topic_in_table`,
            `out_table`) mapping between strings and int64 ids;
          * string placeholders for encoder inputs, topics and decoder
            inputs plus int32 length placeholders;
          * the word and (non-trainable) topic embeddings and the output
            projection `self.out_proj = (w, b)`;
          * a time-major encoder (bidirectional if `conf.bidirectional`);
          * an `AttnTopicDecoder` attending over encoder outputs and over
            projected topic embeddings.

        When `self.for_deploy` is false, the training branch additionally
        builds `self.loss`, `self.summary`, the backprop ops and
        `self.saver`. Otherwise a greedy inference decoder (at most 40
        steps), a saver and a serving exporter are built. In both cases
        `self.outputs` holds the decoded strings and `self.dec_states` the
        final decoder state.
        """
        conf = self.conf
        dtype = self.dtype

        # Input maps
        self.in_table = lookup.MutableHashTable(key_dtype=tf.string,
                                                value_dtype=tf.int64,
                                                default_value=UNK_ID,
                                                shared_name="in_table",
                                                name="in_table",
                                                checkpoint=True)

        self.topic_in_table = lookup.MutableHashTable(
            key_dtype=tf.string,
            value_dtype=tf.int64,
            default_value=2,
            shared_name="topic_in_table",
            name="topic_in_table",
            checkpoint=True)

        self.out_table = lookup.MutableHashTable(key_dtype=tf.int64,
                                                 value_dtype=tf.string,
                                                 default_value="_UNK",
                                                 shared_name="out_table",
                                                 name="out_table",
                                                 checkpoint=True)

        graphlg.info("Creating placeholders...")
        self.enc_str_inps = tf.placeholder(tf.string,
                                           shape=(None, conf.input_max_len),
                                           name="enc_inps")
        self.enc_lens = tf.placeholder(tf.int32, shape=[None], name="enc_lens")

        self.enc_str_topics = tf.placeholder(tf.string,
                                             shape=(None, None),
                                             name="enc_topics")

        # Decoder strings are two columns wider than output_max_len; the
        # slices further down split them into inputs and shifted targets.
        self.dec_str_inps = tf.placeholder(
            tf.string, shape=[None, conf.output_max_len + 2], name="dec_inps")
        self.dec_lens = tf.placeholder(tf.int32, shape=[None], name="dec_lens")

        # table lookup
        self.enc_inps = self.in_table.lookup(self.enc_str_inps)
        self.enc_topics = self.topic_in_table.lookup(self.enc_str_topics)
        self.dec_inps = self.in_table.lookup(self.dec_str_inps)

        batch_size = tf.shape(self.enc_inps)[0]

        with variable_scope.variable_scope(self.model_kind,
                                           dtype=dtype) as scope:
            # Create encode graph and get attn states
            graphlg.info("Creating embeddings and do lookup...")
            # The encoder below runs with time_major=True.
            t_major_enc_inps = tf.transpose(self.enc_inps)
            with ops.device("/cpu:0"):
                self.embedding = variable_scope.get_variable(
                    "embedding", [conf.input_vocab_size, conf.embedding_size])
                self.emb_enc_inps = embedding_lookup_unique(
                    self.embedding, t_major_enc_inps)
                self.topic_embedding = variable_scope.get_variable(
                    "topic_embedding",
                    [conf.topic_vocab_size, conf.topic_embedding_size],
                    trainable=False)
                self.emb_enc_topics = embedding_lookup_unique(
                    self.topic_embedding, self.enc_topics)

            graphlg.info("Creating out projection weights...")
            # The projection input width is conf.out_layer_size when it is
            # configured and conf.num_units otherwise.
            if conf.out_layer_size is not None:
                proj_in_size = conf.out_layer_size
            else:
                proj_in_size = conf.num_units
            w = tf.get_variable("proj_w",
                                [proj_in_size, conf.output_vocab_size],
                                dtype=dtype)
            b = tf.get_variable("proj_b", [conf.output_vocab_size],
                                dtype=dtype)
            self.out_proj = (w, b)

            graphlg.info("Creating encoding dynamic rnn...")
            with variable_scope.variable_scope("encoder",
                                               dtype=dtype) as scope:
                if conf.bidirectional:
                    cell_fw = CreateMultiRNNCell(conf.cell_model,
                                                 conf.num_units,
                                                 conf.num_layers,
                                                 conf.output_keep_prob)
                    cell_bw = CreateMultiRNNCell(conf.cell_model,
                                                 conf.num_units,
                                                 conf.num_layers,
                                                 conf.output_keep_prob)
                    self.enc_outs, self.enc_states = bidirectional_dynamic_rnn(
                        cell_fw=cell_fw,
                        cell_bw=cell_bw,
                        inputs=self.emb_enc_inps,
                        sequence_length=self.enc_lens,
                        dtype=dtype,
                        parallel_iterations=16,
                        time_major=True,
                        scope=scope)
                    # Merge both directions: states are concatenated per
                    # layer on axis 1, outputs on the feature axis.
                    fw_s, bw_s = self.enc_states
                    self.enc_states = tuple([
                        tf.concat([fw, bw], axis=1)
                        for fw, bw in zip(fw_s, bw_s)
                    ])
                    self.enc_outs = tf.concat(
                        [self.enc_outs[0], self.enc_outs[1]], axis=2)
                else:
                    cell = CreateMultiRNNCell(conf.cell_model, conf.num_units,
                                              conf.num_layers,
                                              conf.output_keep_prob)
                    self.enc_outs, self.enc_states = dynamic_rnn(
                        cell=cell,
                        inputs=self.emb_enc_inps,
                        sequence_length=self.enc_lens,
                        parallel_iterations=16,
                        scope=scope,
                        dtype=dtype,
                        time_major=True)
            # NOTE(review): attn_len is not used anywhere below; it is kept
            # so that the set of ops added to the graph stays the same.
            attn_len = tf.shape(self.enc_outs)[0]

            graphlg.info("Preparing init attention and states for decoder...")
            initial_state = self.enc_states
            # Attention states are batch-major, encoder outputs time-major.
            attn_states = tf.transpose(self.enc_outs, perm=[1, 0, 2])
            attn_size = self.conf.num_units
            topic_attn_size = self.conf.num_units
            # A 1x1 convolution projects every topic embedding from
            # topic_embedding_size to topic_attn_size channels.
            k = tf.get_variable(
                "topic_proj",
                [1, 1, self.conf.topic_embedding_size, topic_attn_size])
            topic_attn_states = nn_ops.conv2d(
                tf.expand_dims(self.emb_enc_topics, 2), k, [1, 1, 1, 1],
                "SAME")
            topic_attn_states = tf.squeeze(topic_attn_states, axis=2)

            graphlg.info("Creating decoder cell...")
            with variable_scope.variable_scope("decoder",
                                               dtype=dtype) as scope:
                cell = CreateMultiRNNCell(conf.cell_model, attn_size,
                                          conf.num_layers,
                                          conf.output_keep_prob)
                # topic
                if not self.for_deploy:
                    graphlg.info(
                        "Embedding decoder inps, tars and tar weights...")
                    t_major_dec_inps = tf.transpose(self.dec_inps)
                    # Targets are the decoder inputs shifted by one step.
                    t_major_tars = tf.slice(t_major_dec_inps, [1, 0],
                                            [conf.output_max_len + 1, -1])
                    t_major_dec_inps = tf.slice(t_major_dec_inps, [0, 0],
                                                [conf.output_max_len + 1, -1])
                    # Reverse cumsum of a one-hot at index dec_lens - 1
                    # yields 1.0 for every step up to and including that
                    # index and 0.0 afterwards, i.e. a per-step loss mask.
                    t_major_tar_wgts = tf.cumsum(tf.one_hot(
                        self.dec_lens - 1, conf.output_max_len + 1, axis=0),
                                                 axis=0,
                                                 reverse=True)
                    with ops.device("/cpu:0"):
                        emb_dec_inps = embedding_lookup_unique(
                            self.embedding, t_major_dec_inps)

                    # NOTE(review): the helper consumes decoder inputs but
                    # is given the *encoder* lengths; self.dec_lens looks
                    # like the intended value -- TODO confirm before
                    # changing, since it affects training behaviour.
                    hp_train = helper.ScheduledEmbeddingTrainingHelper(
                        inputs=emb_dec_inps,
                        sequence_length=self.enc_lens,
                        embedding=self.embedding,
                        sampling_probability=0.0,
                        out_proj=self.out_proj,
                        except_ids=None,
                        time_major=True)

                    output_layer = None
                    my_decoder = AttnTopicDecoder(
                        cell=cell,
                        helper=hp_train,
                        initial_state=initial_state,
                        attn_states=attn_states,
                        attn_size=attn_size,
                        topic_attn_states=topic_attn_states,
                        topic_attn_size=topic_attn_size,
                        output_layer=output_layer)
                    t_major_cell_outs, final_state = decoder.dynamic_decode(
                        decoder=my_decoder,
                        output_time_major=True,
                        maximum_iterations=conf.output_max_len + 1,
                        scope=scope)
                    t_major_outs = t_major_cell_outs.rnn_output

                    # Branch 1 for debugging, doesn't have to be called
                    self.outputs = tf.transpose(t_major_outs, perm=[1, 0, 2])
                    L = tf.shape(self.outputs)[1]
                    w, b = self.out_proj
                    # Flatten to 2-D so one matmul projects every step.
                    self.outputs = tf.reshape(self.outputs,
                                              [-1, int(w.shape[0])])
                    self.outputs = tf.matmul(self.outputs, w) + b

                    # For masking the except_ids when debuging
                    #m = tf.shape(self.outputs)[0]
                    #self.mask = tf.zeros([m, int(w.shape[1])])
                    #for i in [3]:
                    #    self.mask = self.mask + tf.one_hot(indices=tf.ones([m], dtype=tf.int32) * i, on_value=100.0, depth=int(w.shape[1]))
                    #self.outputs = self.outputs - self.mask

                    self.outputs = tf.argmax(self.outputs, axis=1)
                    self.outputs = tf.reshape(self.outputs, [-1, L])
                    self.outputs = self.out_table.lookup(
                        tf.cast(self.outputs, tf.int64))

                    # Branch 2 for loss
                    self.loss = dyn_sequence_loss(self.conf, t_major_outs,
                                                  self.out_proj, t_major_tars,
                                                  t_major_tar_wgts)
                    self.summary = tf.summary.scalar("%s/loss" % self.name,
                                                     self.loss)

                    # backpropagation
                    self.build_backprop(self.loss, conf, dtype)

                    #saver
                    # topic_embedding is created with trainable=False, so it
                    # is appended explicitly to be tracked and saved.
                    self.trainable_params.extend(tf.trainable_variables() +
                                                 [self.topic_embedding])
                    need_to_save = (
                        self.global_params + self.trainable_params +
                        self.optimizer_params +
                        tf.get_default_graph().get_collection(
                            "saveable_objects") + [self.topic_embedding])
                    self.saver = tf.train.Saver(need_to_save,
                                                max_to_keep=conf.max_to_keep)
                else:
                    # Every sequence starts from token id 1 (presumably the
                    # GO symbol -- verify against the vocabulary).
                    hp_infer = helper.GreedyEmbeddingHelper(
                        embedding=self.embedding,
                        start_tokens=tf.ones(shape=[batch_size],
                                             dtype=tf.int32),
                        end_token=EOS_ID,
                        out_proj=self.out_proj)

                    output_layer = None  #layers_core.Dense(self.conf.outproj_from_size, use_bias=True)
                    my_decoder = AttnTopicDecoder(
                        cell=cell,
                        helper=hp_infer,
                        initial_state=initial_state,
                        attn_states=attn_states,
                        attn_size=attn_size,
                        topic_attn_states=topic_attn_states,
                        topic_attn_size=topic_attn_size,
                        output_layer=output_layer)
                    cell_outs, final_state = decoder.dynamic_decode(
                        decoder=my_decoder, scope=scope, maximum_iterations=40)
                    self.outputs = cell_outs.sample_id
                    #lookup
                    self.outputs = self.out_table.lookup(
                        tf.cast(self.outputs, tf.int64))

                    #saver
                    self.trainable_params.extend(tf.trainable_variables())
                    self.saver = tf.train.Saver(max_to_keep=conf.max_to_keep)

                    # Exporter for serving
                    self.model_exporter = exporter.Exporter(self.saver)
                    inputs = {
                        "enc_inps": self.enc_str_inps,
                        "enc_lens": self.enc_lens
                    }
                    outputs = {"out": self.outputs}
                    self.model_exporter.init(
                        tf.get_default_graph().as_graph_def(),
                        named_graph_signatures={
                            "inputs": exporter.generic_signature(inputs),
                            "outputs": exporter.generic_signature(outputs)
                        })
                    graphlg.info("Graph done")
                    graphlg.info("")

                self.dec_states = final_state
Exemple #4
0
# Random stand-in for the embedded target sequence fed to the training
# decoder; the helper consumes it step by step, bounded by tgt_len.
tgt_inputs = tf.random.normal((batch_size, tgt_max_times, num_units),
                              dtype=tf.float32)
training_helper = helper_py.TrainingHelper(tgt_inputs, tgt_len)

# train helper
train_decoder = basic_decoder.BasicDecoder(
    cell=attnRNNCell,
    helper=training_helper,
    initial_state=attnRNNCell.zero_state(batch_size, tf.float32))

# inference
embedding = tf.get_variable("embedding",
                            shape=(10, 16),
                            initializer=tf.random_uniform_initializer())
infer_helper = helper_py.GreedyEmbeddingHelper(
    embedding=embedding,  # may be a callable or an embedding matrix
    start_tokens=tf.zeros([batch_size], dtype=tf.int32),
    end_token=9)
infer_decoder = basic_decoder.BasicDecoder(
    cell=attnRNNCell,
    helper=infer_helper,
    initial_state=attnRNNCell.zero_state(batch_size, tf.float32))

# maximum_iterations must be an int32 scalar or None. The previous value,
# False, is not a valid step limit; None leaves the number of steps to the
# TrainingHelper's sequence lengths.
final_outputs, final_state, final_sequence_lengths = decoder.dynamic_decode(
    train_decoder, maximum_iterations=None)

print(final_outputs.rnn_output)
print(final_outputs.sample_id)
print(final_state.cell_state)
print(final_sequence_lengths)

print("--------infer-------------")
final_outputs, final_state, final_sequence_lengths = decoder.dynamic_decode(