Example #1
0
    def beam_search_step(self,
                         input,
                         state,
                         cell,
                         beam_size,
                         attention_construct_fn=None,
                         input_text=None):
        output, state = cell(input, state)

        if hasattr(state, 'alignments'):
            tf.add_to_collection('attention_alignments', state.alignments)
            tf.add_to_collection('beam_search_alignments',
                                 tf.get_collection('attention_alignments')[-1])

        #TODO: after initialization, attention decoding at each step still needs input_text to be fed.
        #Will this cause attention_keys and attention_values to be recomputed (i.e. the encoding redone) each step?
        #Can we avoid this? There seems to be no better method;
        #if encoding is slow, maybe feed attention_keys and attention_values at each step instead.
        if not FLAGS.decode_use_alignment:
            if FLAGS.gen_only:
                output_fn = self.output_fn
                logits = output_fn(output)
            else:
                indices = melt.batch_values_to_indices(tf.to_int32(input_text))
                batch_size = melt.get_batch_size(input)

                if FLAGS.copy_only:
                    output_fn_ = self.copy_output_fn
                else:
                    output_fn_ = self.gen_copy_output_fn
                output_fn = lambda cell_output, cell_state: output_fn_(
                    indices, batch_size, cell_output, cell_state)

                logits = output_fn(output, state)

            if FLAGS.gen_copy_switch and FLAGS.switch_after_softmax:
                logprobs = tf.log(logits)
            else:
                logprobs = tf.nn.log_softmax(logits)

            if FLAGS.decode_copy:
                logprobs = melt.gather_cols(logprobs, tf.to_int32(input_text))
        else:
            logits = state.alignments
            logits = logits[:, :tf.shape(input_text)[-1]]
            logprobs = tf.nn.log_softmax(logits)

        top_logprobs, top_ids = tf.nn.top_k(logprobs, beam_size)
        #------too slow... transferring large data between Python and C++ costs a lot!
        #top_logprobs, top_ids = tf.nn.top_k(logprobs, self.vocab_size)

        if input_text is not None and FLAGS.decode_copy:
            top_ids = tf.nn.embedding_lookup(input_text, top_ids)

        if hasattr(state, 'cell_state'):
            state = state.cell_state

        return output, state, top_logprobs, top_ids
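
A minimal NumPy sketch of the per-step selection above (log-softmax over the vocabulary, then keep the beam_size best continuations); top_k_logprobs is a hypothetical helper for illustration, not part of melt:

import numpy as np

def top_k_logprobs(logits, beam_size):
    # Log-softmax per row, then keep the beam_size best ids (mirrors tf.nn.top_k above).
    shifted = logits - logits.max(axis=-1, keepdims=True)
    logprobs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    top_ids = np.argsort(-logprobs, axis=-1)[:, :beam_size]
    return np.take_along_axis(logprobs, top_ids, axis=-1), top_ids

# batch of 2, vocab of 5, beam size 3
print(top_k_logprobs(np.random.randn(2, 5).astype(np.float32), beam_size=3))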
Example #2
0
    def generate_sequence_greedy(self,
                                 input,
                                 max_words,
                                 initial_state=None,
                                 attention_states=None,
                                 convert_unk=True,
                                 input_text=None,
                                 emb=None):
        """
    this one is using greedy search method
    for beam search using generate_sequence_by_beam_search with addditional params like beam_size
    """
        if emb is None:
            emb = self.emb

        batch_size = melt.get_batch_size(input)
        if attention_states is None:
            cell = self.cell
        else:
            cell = self.prepare_attention(
                attention_states,
                initial_state=initial_state,
                score_as_alignment=self.score_as_alignment)
            initial_state = None
        state = cell.zero_state(
            batch_size, tf.float32) if initial_state is None else initial_state

        helper = melt.seq2seq.GreedyEmbeddingHelper(embedding=emb,
                                                    first_input=input,
                                                    end_token=self.end_id)

        if FLAGS.gen_only:
            output_fn = self.output_fn
        else:
            indices = melt.batch_values_to_indices(tf.to_int32(input_text))
            if FLAGS.copy_only:
                output_fn_ = self.copy_output_fn
            else:
                output_fn_ = self.gen_copy_output_fn
            output_fn = lambda cell_output, cell_state: output_fn_(
                indices, batch_size, cell_output, cell_state)

        my_decoder = melt.seq2seq.BasicDecoder(cell=cell,
                                               helper=helper,
                                               initial_state=state,
                                               vocab_size=self.vocab_size,
                                               output_fn=output_fn)

        outputs, _, _ = melt.seq2seq.dynamic_decode(
            my_decoder, maximum_iterations=max_words, scope=self.scope)
        generated_sequence = outputs.sample_id
        #------like beam search, return (sequence, score)
        return generated_sequence, tf.zeros([
            batch_size,
        ])
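
A rough sketch of the loop that GreedyEmbeddingHelper plus dynamic_decode unroll, assuming a hypothetical step_fn(inp, state) -> (logits, state) standing in for cell plus output_fn:

import numpy as np

def greedy_decode(step_fn, first_input, state, max_words, end_id):
    # At each step pick the argmax token and feed it back until end_id or max_words.
    inp, tokens = first_input, []
    for _ in range(max_words):
        logits, state = step_fn(inp, state)
        token = int(np.argmax(logits))
        tokens.append(token)
        if token == end_id:
            break
        inp = token  # the real helper feeds the embedding of this id
    return tokens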
Example #3
0
 def words_importance_encode(self, sequence, emb=None, input=None):
     #[batch_size, emb_dim]
     argmax_values = self.encode(
         sequence, emb, input,
         output_method=melt.rnn.OutputMethod.argmax)[0]
     indices = melt.batch_values_to_indices(tf.to_int32(argmax_values))
     updates = tf.ones_like(argmax_values)
     shape = tf.shape(sequence)
     scores = tf.scatter_nd(indices, updates, shape=shape) * tf.to_int64(
         tf.sequence_mask(self.sequence_length, shape[-1]))
     return scores
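
A toy NumPy version of the scatter above, assuming argmax_values holds per-unit argmax positions over time and positions past each true length should be masked out; the helper name is made up for illustration:

import numpy as np

def words_importance_scores(argmax_positions, seq_len, lengths):
    # Mark positions picked by the argmax pooling, then zero out padding.
    scores = np.zeros((argmax_positions.shape[0], seq_len), dtype=np.int64)
    for b, positions in enumerate(argmax_positions):
        scores[b, positions] = 1
        scores[b, lengths[b]:] = 0
    return scores

print(words_importance_scores(np.array([[0, 2, 2], [1, 3, 4]]), seq_len=5, lengths=[3, 5]))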
Example #4
0
        def gen_copy_output_train(time, indices, targets, sampled_values,
                                  batch_size, cell_output, cell_state):
            if self.softmax_loss_function is not None:
                labels = tf.slice(targets, [0, time], [-1, 1])

                sampled, true_expected_count, sampled_expected_count = sampled_values
                sampled_values = \
                  sampled, tf.slice(tf.reshape(true_expected_count, [batch_size, -1]), [0, time], [-1, 1]), sampled_expected_count

                sampled_ids, sampled_logits = melt.nn.compute_sampled_ids_and_logits(
                    weights=self.w_t,
                    biases=self.v,
                    labels=labels,
                    inputs=cell_output,
                    num_sampled=self.num_sampled,
                    num_classes=self.vocab_size,
                    sampled_values=sampled_values,
                    remove_accidental_hits=False)
                gen_indices = melt.batch_values_to_indices(
                    tf.to_int32(sampled_ids))
                gen_logits = tf.scatter_nd(gen_indices,
                                           sampled_logits,
                                           shape=[batch_size, self.vocab_size])
            else:
                gen_logits = self.output_fn(cell_output)

            copy_logits = copy_output(indices, batch_size, cell_output,
                                      cell_state)

            if FLAGS.gen_copy_switch:
                #gen_copy_switch == True.
                gen_probability = cell_state.gen_probability
                if FLAGS.switch_after_softmax:
                    return gen_probability * tf.nn.softmax(gen_logits) + (
                        1 - gen_probability) * tf.nn.softmax(copy_logits)
                else:
                    return gen_probability * gen_logits + (
                        1 - gen_probability) * copy_logits
            else:
                return gen_logits + copy_logits
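
A small NumPy sketch of the gen/copy switch at the end of the function: mix the two distributions after softmax, or mix raw logits, depending on FLAGS.switch_after_softmax:

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def mix_gen_copy(gen_logits, copy_logits, gen_probability, switch_after_softmax=True):
    # gen_probability weights the generation distribution, (1 - gen_probability) the copy distribution.
    if switch_after_softmax:
        return gen_probability * softmax(gen_logits) + (1. - gen_probability) * softmax(copy_logits)
    return gen_probability * gen_logits + (1. - gen_probability) * copy_logits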
Example #5
0
    def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
        """Decoder function used in the `dynamic_rnn_decoder` for training.

    Args:
      time: positive integer constant reflecting the current timestep.
      cell_state: state of RNNCell.
      cell_input: input provided by `dynamic_rnn_decoder`.
      cell_output: output of RNNCell.
      context_state: context state provided by `dynamic_rnn_decoder`.

    Returns:
      A tuple (done, next state, next input, emit output, next context state)
      where:

      done: `None`, which is used by the `dynamic_rnn_decoder` to indicate
      that `sequence_lengths` in `dynamic_rnn_decoder` should be used.

      next state: `cell_state`, this decoder function does not modify the
      given state.

      next input: `cell_input`, this decoder function does not modify the
      given input. The input could be modified when applying e.g. attention.

      emit output: `cell_output`, this decoder function does not modify the
      given output.

      next context state: `context_state`, this decoder function does not
      modify the given context state. The context state could be modified when
      applying e.g. beam search.
    """
        with ops.name_scope(
                name, "attention_decoder_fn_train",
            [time, cell_state, cell_input, cell_output, context_state]):
            #input_text = None

            if cell_state is None:  # first call, return encoder_state
                cell_state = encoder_state

                # init attention
                attention = init_attention(encoder_state)
                if input_text is not None:
                    #cell_output = array_ops.zeros([vocab_size], dtype=dtypes.float32)
                    context_state = tensor_array_ops.TensorArray(
                        dtype=tf.float32,
                        size=0,
                        dynamic_size=True,
                        infer_shape=False)
            else:
                # construct attention
                attention, scores, alignments = attention_construct_fn(
                    cell_output, attention_keys, attention_values)
                #attention, scores, alignments = attention_construct_fn(cell_state.h, attention_keys,
                #                                                       attention_values)
                cell_output = attention
                if input_text is not None:
                    #encoder_info = nest.flatten(encoder_state)[0]
                    #batch_size = encoder_info.get_shape()[0].value
                    #if batch_size is None:
                    #  batch_size = array_ops.shape(encoder_info)[0]
                    ##ref =  tf.Variable(array_ops.zeros([batch_size, vocab_size], dtype=dtypes.float32))
                    ##https://github.com/tensorflow/tensorflow/issues/8604
                    #with tf.control_dependencies(None):
                    ##TODO... must fix batch_size right now
                    #  cell_output_ =  tf.Variable(array_ops.zeros([256, vocab_size], dtype=dtypes.float32))

                    #actually should use tf.scatter_nd, anyway this attention is deprecated!
                    cell_output_ = tf.get_variable(
                        "cell_output_", [256, vocab_size],
                        dtype=dtypes.float32,
                        initializer=tf.zeros_initializer())
                    #cell_output_ =  tf.get_variable("cell_output_", [256, vocab_size], dtype=dtypes.float32, initializer=tf.ones_initializer())
                    #cell_output_.assign(array_ops.zeros([256, vocab_size]))
                    cell_output_ = tf.assign(
                        cell_output_, array_ops.zeros([256, vocab_size]))
                    indices = melt.batch_values_to_indices(
                        tf.to_int32(input_text))
                    updates = scores
                    #updates = alignments
                    cell_output_ = tf.scatter_nd_add(cell_output_, indices,
                                                     updates)
                    cell_output_ = tf.stop_gradient(cell_output_)
                    ##print(cell_output, vocab_size)
                    ##cell_output = attention
                    ##cell_output = array_ops.zeros([batch_size, vocab_size], dtype=dtypes.float32)
                    ##cell_output = scores
                    #cell_output = tf.convert_to_tensor(cell_output)

                    context_state = context_state.write(time - 1, cell_output_)
                    #context_state = context_state.write(time - 1, scores)

            # combine cell_input and attention
            next_input = array_ops.concat([cell_input, attention], 1)
            #next_input = cell_input

            return (None, cell_state, next_input, cell_output, context_state)
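
For reference, a NumPy sketch of what the deprecated block above does with the fixed-size variable and scatter_nd_add: accumulate each source position's attention score onto its vocabulary id (a hypothetical helper, not part of melt):

import numpy as np

def scatter_attention_scores(input_text_ids, scores, vocab_size):
    # out[b, input_text_ids[b, t]] += scores[b, t]; repeated ids accumulate.
    batch_size, src_len = input_text_ids.shape
    out = np.zeros((batch_size, vocab_size), dtype=np.float32)
    for b in range(batch_size):
        for t in range(src_len):
            out[b, input_text_ids[b, t]] += scores[b, t]
    return out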
Example #6
0
  def take_step(self, i, prev, state):
    print('-------------i', i)
    if self.output_fn is not None:
      #[batch_size * beam_size, num_units] -> [batch_size * beam_size, num_classes]
      try:
        output = self.output_fn(prev)
      except Exception:
        output = self.output_fn(prev, state)
    else:
      output = prev

    self.output = output

    #[batch_size * beam_size, num_classes], here use log softmax
    if self.need_softmax:
      logprobs = tf.nn.log_softmax(output)
    else:
      logprobs = tf.log(tf.maximum(output, 1e-12))
    
    if self.num_classes is None:
      self.num_classes = tf.shape(logprobs)[1]

    #->[batch_size, beam_size, num_classes]
    logprobs_batched = tf.reshape(logprobs,
                                  [-1, self.beam_size, self.num_classes])
    logprobs_batched.set_shape((None, self.beam_size, None))
    
    # Note: masking out entries to -inf plays poorly with top_k, so just subtract out a large number.
    nondone_mask = tf.reshape(
        tf.cast(
          tf.equal(tf.range(self.num_classes), self.done_token),
          tf.float32) * -1e18,
        [1, 1, self.num_classes])

    if self.past_logprobs is None:
      #[batch_size, beam_size, num_classes] -> [batch_size, num_classes]
      #-> past_logprobs[batch_size, beam_size], indices[batch_size, beam_size]
      self.past_logprobs, indices = tf.nn.top_k(
          (logprobs_batched + nondone_mask)[:, 0, :],
          self.beam_size)
      step_logprobs = self.past_logprobs
    else:
      #logprobs_batched [batch_size, beam_size, num_classes] -> [batch_size, beam_size, num_classes]  
      #past_logprobs    [batch_size, beam_size] -> [batch_size, beam_size, 1]
      step_logprobs_batched = logprobs_batched
      logprobs_batched = logprobs_batched + tf.expand_dims(self.past_logprobs, 2)


      #get [batch_size, beam_size] each
      self.past_logprobs, indices = tf.nn.top_k(
          #[batch_size, beam_size * num_classes]
          tf.reshape(logprobs_batched + nondone_mask, 
                     [-1, self.beam_size * self.num_classes]),
          self.beam_size)  

      #get current step logprobs [batch_size, beam_size]
      step_logprobs = tf.gather_nd(tf.reshape(step_logprobs_batched, 
                                              [-1, self.beam_size * self.num_classes]), 
                                   melt.batch_values_to_indices(indices))

    # For continuing to the next symbols [batch_size, beam_size]
    symbols = indices % self.num_classes
    #from which beam it comes  [batch_size, beam_size]
    parent_refs = indices // self.num_classes
    
    if self.past_symbols is None:
      #here i == 1; at i == 0 take_step is not called, we just run the rnn once and use its output here at i == 1
      #no need to gather state here since the initial state of each beam is the same
      #[batch_size, beam_size] -> [batch_size, beam_size, 1]
      self.past_symbols = tf.expand_dims(symbols, 2)
      self.past_step_logprobs = tf.expand_dims(step_logprobs, 2)
    else:
      # NOTE: outputting a zero-length sequence is not supported for simplicity reasons
      #see hasky/jupter/tensorflow/beam-search2.ipynb for merging paths
      #here when i >= 2
      # tf.reshape(
      #           (tf.range(3 * 5) // 5) * 5,
      #           [3, 5]
      #       ).eval()
      # array([[ 0,  0,  0,  0,  0],
      #        [ 5,  5,  5,  5,  5],
      #        [10, 10, 10, 10, 10]], dtype=int32)
      parent_refs_offsets = tf.reshape(
          (tf.range(self.batch_size * self.beam_size) 
           // self.beam_size) * self.beam_size,
          [self.batch_size, self.beam_size])
      
      #self.past_symbols [batch_size, beam_size, i - 1] -> past_symbols_batch_major [batch_size * beam_size, i - 1]
      past_symbols_batch_major = tf.reshape(self.past_symbols, [-1, i-1])

      past_step_logprobs_batch_major = tf.reshape(self.past_step_logprobs, [-1, i - 1])
     
      #[batch_size, beam_size]
      past_indices = parent_refs + parent_refs_offsets 
      #-> [batch_size, beam_size, i - 1]  
      beam_past_symbols = tf.gather(past_symbols_batch_major,            #[batch_size * beam_size, i - 1]
                                    past_indices                         #[batch_size, beam_size]
                                    )

      beam_past_step_logprobs = tf.gather(past_step_logprobs_batch_major, past_indices)

      #we must also choose corresponding past state as new start
      past_indices = tf.reshape(past_indices, [-1])

      #TODO: tf.TensorArray is not supported right now, so alignment_history can not be used in attention_wrapper
      def try_gather(x, indices):
        #if isinstance(x, tf.Tensor) and x.shape.ndims >= 2:
        assert isinstance(x, tf.Tensor)
        if x.shape.ndims >= 2:
          return tf.gather(x, indices)
        else:
          return x

      state = nest.map_structure(lambda x: try_gather(x, past_indices), state)

      if hasattr(state, 'alignments'):
        attention_size = melt.get_shape(state.alignments, -1)
        alignments = tf.reshape(state.alignments, [-1, self.beam_size, attention_size])
        print('alignments', alignments)

      if not self.fast_greedy:
        #[batch_size, beam_size, max_len]
        path = tf.concat([self.past_symbols, 
                          tf.ones_like(tf.expand_dims(symbols, 2)) * self.done_token,
                          tf.tile(tf.ones_like(tf.expand_dims(symbols, 2)) * self.pad_token, 
                          [1, 1, self.max_len - i])], 2)

        step_logprobs_path = tf.concat([self.past_step_logprobs, 
                                        tf.expand_dims(step_logprobs_batched[:, :, self.done_token], 2),
                                        tf.tile(tf.ones_like(tf.expand_dims(step_logprobs, 2)) * -float('inf'), 
                                                [1, 1, self.max_len - i])], 2)

        #[batch_size, 1, beam_size, max_len]
        path = tf.expand_dims(path, 1)
        step_logprobs_path = tf.expand_dims(step_logprobs_path, 1)
        self.paths_list.append(path)
        self.step_logprobs_list.append(step_logprobs_path)

      #[batch_size * beam_size, i - 1] -> [batch_size, beam_size, i] the best beam_size paths until step i
      self.past_symbols = tf.concat([beam_past_symbols, tf.expand_dims(symbols, 2)], 2)
      self.past_step_logprobs = tf.concat([beam_past_step_logprobs, tf.expand_dims(step_logprobs, 2)], 2)

      # For finishing the beam 
      #[batch_size, beam_size]
      logprobs_done = logprobs_batched[:, :, self.done_token]
      if not self.fast_greedy:
        self.logprobs_list.append(logprobs_done / i ** self.length_normalization_factor)
      else:
        done_parent_refs = tf.cast(tf.argmax(logprobs_done, 1), tf.int32)
        done_parent_refs_offsets = tf.range(self.batch_size) * self.beam_size

        done_past_symbols = tf.gather(past_symbols_batch_major,
                                      done_parent_refs + done_parent_refs_offsets)

        #[batch_size, max_len]
        symbols_done = tf.concat([done_past_symbols,
                                     tf.ones_like(done_past_symbols[:,0:1]) * self.done_token,
                                     tf.tile(tf.zeros_like(done_past_symbols[:,0:1]),
                                             [1, self.max_len - i])
                                    ], 1)

        #[batch_size, beam_size] -> [batch_size,]
        logprobs_done_max = tf.reduce_max(logprobs_done, 1)
      
        if self.length_normalization_factor > 0:
          logprobs_done_max /= i ** self.length_normalization_factor

        #[batch_size, max_len]
        self.finished_beams = tf.where(logprobs_done_max > self.logprobs_finished_beams,
                                       symbols_done,
                                       self.finished_beams)

        self.logprobs_finished_beams = tf.maximum(logprobs_done_max, self.logprobs_finished_beams)

    #->[batch_size * beam_size,]
    symbols_flat = tf.reshape(symbols, [-1])

    self.final_state = state 
    return symbols_flat, state 
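
A compact NumPy sketch of the beam bookkeeping above: top-k is taken over a flattened [beam_size * num_classes] axis, so the new symbol is index % num_classes, the source beam is index // num_classes, and the chosen beam's history must be re-gathered before appending the new symbol:

import numpy as np

def backtrack_step(indices, past_symbols, num_classes):
    # indices: [batch, beam] picks over the flattened beam*vocab axis
    symbols = indices % num_classes       # new token for each pick
    parent_refs = indices // num_classes  # which beam each pick extends
    new_past = np.stack([past_symbols[b, parent_refs[b]]
                         for b in range(indices.shape[0])])
    return np.concatenate([new_past, symbols[..., None]], axis=-1)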
Example #7
0
    def sequence_loss(self,
                      sequence,
                      initial_state=None,
                      attention_states=None,
                      input=None,
                      input_text=None,
                      exact_prob=False,
                      exact_loss=False,
                      emb=None):
        """
    for general seq2seq input is None, sequence will pad <GO>, inital_state is last state from encoder
    for img2text/showandtell input is image_embedding, inital_state is None/zero set
    TODO since exact_porb and exact_loss same value, may remove exact_prob
    NOTICE! assume sequence to be padded by zero and must have one instance full length(no zero!)
    """
        if emb is None:
            emb = self.emb

        is_training = self.is_training
        batch_size = tf.shape(sequence)[0]

        sequence, sequence_length = melt.pad(sequence,
                                             start_id=self.get_start_id(),
                                             end_id=self.get_end_id())

        #TODO different init state as shown in ptb_word_lm
        state = self.cell.zero_state(
            batch_size, tf.float32) if initial_state is None else initial_state

        #[batch_size, num_steps - 1, emb_dim], remove last col
        inputs = tf.nn.embedding_lookup(emb, sequence[:, :-1])

        if is_training and FLAGS.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, FLAGS.keep_prob)

        #inputs[batch_size, num_steps, emb_dim] input([batch_size, emb_dim] -> [batch_size, 1, emb_dim]) before concat
        if input is not None:
            #used like showandtell, where image_emb is the first input in addition to sequence
            inputs = tf.concat([tf.expand_dims(input, 1), inputs], 1)
        else:
            #common usage: input is None and sequence is the input; note <GO> was already padded by melt.pad above
            sequence_length -= 1
            sequence = sequence[:, 1:]

        if self.is_predict:
            #---only needed for predict, since train input is already dynamic length; NOTICE this improves speed a lot
            num_steps = tf.to_int32(tf.reduce_max(sequence_length))
            sequence = sequence[:, :num_steps]
            inputs = inputs[:, :num_steps, :]

        tf.add_to_collection('sequence', sequence)
        tf.add_to_collection('sequence_length', sequence_length)

        if attention_states is None:
            outputs, state = tf.nn.dynamic_rnn(self.cell,
                                               inputs,
                                               initial_state=state,
                                               sequence_length=sequence_length,
                                               scope=self.scope)
            self.final_state = state
        else:
            attention_keys, attention_values, attention_score_fn, attention_construct_fn = \
              self.prepare_attention(attention_states)
            decoder_fn_train = melt.seq2seq.attention_decoder_fn_train(
                encoder_state=state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn)
            decoder_outputs_train, decoder_state_train, _ = \
                          melt.seq2seq.dynamic_rnn_decoder(
                              cell=self.cell,
                              decoder_fn=decoder_fn_train,
                              inputs=inputs,
                              sequence_length=tf.to_int32(sequence_length),
                              scope=self.scope)
            outputs = decoder_outputs_train

            self.final_state = decoder_state_train

        tf.add_to_collection('outputs', outputs)

        #[batch_size, num_steps]
        targets = sequence

        if FLAGS.copy_only:
            #TODO: does not work right now!
            attention_scores = tf.get_collection('attention_scores')[-1]
            indices = melt.batch_values_to_indices(input_text)
            #logits = ;
        else:
            #TODO: hack here add FLAGS.predict_no_sample just for Seq2seqPredictor exact_predict
            softmax_loss_function = self.softmax_loss_function
            if self.is_predict and (exact_prob or exact_loss):
                softmax_loss_function = None

            if softmax_loss_function is None:
                #[batch_size, num_steps, num_units] * [num_units, vocab_size]
                # -> logits [batch_size, num_steps, vocab_size] (if use exact_predict_loss)
                #or [batch_size * num_steps, vocab_size] by default flatten=True
                keep_dims = exact_prob or exact_loss
                logits = melt.batch_matmul_embedding(
                    outputs, self.w, keep_dims=keep_dims) + self.v
                if not keep_dims:
                    targets = tf.reshape(targets, [-1])
            else:
                logits = outputs

            mask = tf.cast(tf.sign(targets), dtype=tf.float32)

            if self.is_predict and exact_prob:
                #generate real prob for sequence
                #for 10w vocab textsum seq2seq 20 -> 4 about
                loss = melt.seq2seq.exact_predict_loss(logits, targets, mask,
                                                       num_steps, batch_size)
            elif self.is_predict and exact_loss:
                #force non-sampled softmax loss; the diff with exact_prob is that here we just use the cross-entropy error as the result, not the real prob of the sequence
                #NOTICE this uses a bit less time, 55 vs 57 (prob), and gives the same result as exact prob and exact score
                #but a 256-vocab sample will use only about 10ms
                #TODO check more with softmax loss and sampled softmax loss, check length normalization
                loss = melt.seq2seq.sequence_loss_by_example(logits,
                                                             targets,
                                                             weights=mask)
            else:
                #loss [batch_size,]
                loss = melt.seq2seq.sequence_loss_by_example(
                    logits,
                    targets,
                    weights=mask,
                    softmax_loss_function=softmax_loss_function)

        #mainly for compat with [batch_size, num_losses]
        loss = tf.reshape(loss, [-1, 1])

        if self.is_predict:
            loss = self.normalize_length(loss, sequence_length, exact_prob)
            #loss = tf.squeeze(loss)  TODO: later will uncomment this with all models rerun
        return loss
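
A toy NumPy version of the masked per-example loss contract used above (logits [batch, steps, vocab], targets [batch, steps], mask [batch, steps]); averaging over unmasked steps mirrors sequence_loss_by_example with average_across_timesteps=True, which is assumed here:

import numpy as np

def sequence_loss_by_example(logits, targets, mask):
    # Per-example masked cross-entropy, averaged over real (non-padding) steps.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    logprobs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -np.take_along_axis(logprobs, targets[..., None], axis=-1)[..., 0]
    return (nll * mask).sum(axis=1) / np.maximum(mask.sum(axis=1), 1.0)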
Example #8
0
    def sequence_loss(self,
                      sequence,
                      initial_state=None,
                      attention_states=None,
                      input=None,
                      input_text=None,
                      exact_prob=False,
                      exact_loss=False,
                      emb=None):
        """
    for general seq2seq input is None, sequence will pad <GO>, inital_state is last state from encoder
    for img2text/showandtell input is image_embedding, inital_state is None/zero set
    TODO since exact_porb and exact_loss same value, may remove exact_prob
    NOTICE! assume sequence to be padded by zero and must have one instance full length(no zero!)
    """
        if emb is None:
            emb = self.emb

        is_training = self.is_training
        batch_size = tf.shape(sequence)[0]

        sequence, sequence_length = melt.pad(sequence,
                                             start_id=self.get_start_id(),
                                             end_id=self.get_end_id())

        #[batch_size, num_steps - 1, emb_dim], remove last col
        inputs = tf.nn.embedding_lookup(emb, sequence[:, :-1])

        if is_training and FLAGS.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, FLAGS.keep_prob)

        #inputs[batch_size, num_steps, emb_dim] input([batch_size, emb_dim] -> [batch_size, 1, emb_dim]) before concat
        if input is not None:
            #used like showandtell, where image_emb is the first input in addition to sequence
            inputs = tf.concat([tf.expand_dims(input, 1), inputs], 1)
        else:
            #common usage: input is None and sequence is the input; note <GO> was already padded by melt.pad above
            sequence_length -= 1
            sequence = sequence[:, 1:]

        if self.is_predict:
            #---only needed for predict, since train input is already dynamic length; NOTICE this improves speed a lot
            num_steps = tf.to_int32(tf.reduce_max(sequence_length))
            sequence = sequence[:, :num_steps]
            inputs = inputs[:, :num_steps, :]

        tf.add_to_collection('sequence', sequence)
        tf.add_to_collection('sequence_length', sequence_length)

        if attention_states is None:
            cell = self.cell
        else:
            cell = self.prepare_attention(attention_states,
                                          initial_state=initial_state)
            #initial_state = None
            initial_state = cell.zero_state(batch_size, tf.float32)
        state = cell.zero_state(
            batch_size, tf.float32) if initial_state is None else initial_state

        #if attention_states is None:
        #-----TODO using attention_wrapper works now with dynamic_rnn but is still slower than the old attention method...
        outputs, state = tf.nn.dynamic_rnn(cell,
                                           inputs,
                                           initial_state=state,
                                           sequence_length=sequence_length,
                                           dtype=tf.float32,
                                           scope=self.scope)
        #else:
        #  #---below is also ok but slower; above 16+, below only 13-14 batch/s, may be due to sample id
        #  #TODO: can we make the code below as fast as tf.nn.dynamic_rnn? if sample id is not needed, remove it?
        #  #FIXME... AttentionWrapper is only 1/2 the speed compared to the old function-based attention, why?
        #  #helper = tf.contrib.seq2seq.TrainingHelper(inputs, tf.to_int32(sequence_length))
        #  helper = melt.seq2seq.TrainingHelper(inputs, tf.to_int32(sequence_length))
        #  #my_decoder = tf.contrib.seq2seq.BasicDecoder(
        #  my_decoder = melt.seq2seq.BasicTrainingDecoder(
        #      cell=cell,
        #      helper=helper,
        #      initial_state=state)
        #  outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(my_decoder, scope=self.scope)
        #  #outputs = outputs.rnn_output

        self.final_state = state

        tf.add_to_collection('outputs', outputs)

        #[batch_size, num_steps]
        targets = sequence

        if FLAGS.copy_only:
            #TODO: does not work right now!
            attention_scores = tf.get_collection('attention_scores')[-1]
            indices = melt.batch_values_to_indices(input_text)
            #logits = ;
        else:
            #TODO: hack here add FLAGS.predict_no_sample just for Seq2seqPredictor exact_predict
            softmax_loss_function = self.softmax_loss_function
            if self.is_predict and (exact_prob or exact_loss):
                softmax_loss_function = None

            if softmax_loss_function is None:
                #[batch_size, num_steps, num_units] * [num_units, vocab_size]
                # -> logits [batch_size, num_steps, vocab_size] (if use exact_predict_loss)
                #or [batch_size * num_steps, vocab_size] by default flatten=True
                keep_dims = exact_prob or exact_loss
                logits = melt.batch_matmul_embedding(
                    outputs, self.w, keep_dims=keep_dims) + self.v
                if not keep_dims:
                    targets = tf.reshape(targets, [-1])
            else:
                logits = outputs

            mask = tf.cast(tf.sign(targets), dtype=tf.float32)

            if self.is_predict and exact_prob:
                #generate real prob for sequence
                #for 10w vocab textsum seq2seq 20 -> 4 about
                loss = melt.seq2seq.exact_predict_loss(logits, targets, mask,
                                                       num_steps, batch_size)
            elif self.is_predict and exact_loss:
                #force non-sampled softmax loss; the diff with exact_prob is that here we just use the cross-entropy error as the result, not the real prob of the sequence
                #NOTICE this uses a bit less time, 55 vs 57 (prob), and gives the same result as exact prob and exact score
                #but a 256-vocab sample will use only about 10ms
                #TODO check more with softmax loss and sampled softmax loss, check length normalization
                loss = melt.seq2seq.sequence_loss_by_example(logits,
                                                             targets,
                                                             weights=mask)
            else:
                #loss [batch_size,]
                loss = melt.seq2seq.sequence_loss_by_example(
                    logits,
                    targets,
                    weights=mask,
                    softmax_loss_function=softmax_loss_function)

        #mainly for compat with [batch_size, num_losses]
        loss = tf.reshape(loss, [-1, 1])

        if self.is_predict:
            loss = self.normalize_length(loss, sequence_length, exact_prob)
            #loss = tf.squeeze(loss)  TODO: later will uncomment this with all models rerun
        return loss
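
A small NumPy sketch of the predict-time trimming above: cut padded columns beyond the longest real sequence in the batch before running the RNN:

import numpy as np

def trim_to_max_length(sequence, inputs, sequence_length):
    # sequence: [batch, steps], inputs: [batch, steps, emb_dim]
    num_steps = int(np.max(sequence_length))
    return sequence[:, :num_steps], inputs[:, :num_steps, :]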
Example #9
0
    def generate_sequence_beam(self,
                               input,
                               max_words,
                               initial_state=None,
                               attention_states=None,
                               beam_size=5,
                               convert_unk=True,
                               length_normalization_factor=0.,
                               input_text=None,
                               input_text_length=None,
                               emb=None):
        """
    beam dcode means ingraph beam search
    return top (path, score)
    """
        if emb is None:
            emb = self.emb

        def loop_function(i, prev, state, decoder):
            prev, state = decoder.take_step(i, prev, state)
            next_input = tf.nn.embedding_lookup(emb, prev)
            return next_input, state

        batch_size = melt.get_batch_size(input)

        if initial_state is not None:
            initial_state = nest.map_structure(
                lambda x: tf.contrib.seq2seq.tile_batch(x, beam_size),
                initial_state)
        if attention_states is None:
            cell = self.cell
        else:
            attention_states = tf.contrib.seq2seq.tile_batch(
                attention_states, beam_size)
            #print('tiled_attention_states', attention_states, 'tiled_initial_state', initial_state)
            cell = self.prepare_attention(
                attention_states,
                initial_state=initial_state,
                score_as_alignment=self.score_as_alignment)
            initial_state = None

        state = cell.zero_state(batch_size * beam_size, tf.float32) \
                  if initial_state is None else initial_state

        if FLAGS.gen_only:
            output_fn = self.output_fn
        else:
            input_text = tf.contrib.seq2seq.tile_batch(input_text, beam_size)
            batch_size = batch_size * beam_size
            indices = melt.batch_values_to_indices(tf.to_int32(input_text))
            if FLAGS.copy_only:
                output_fn_ = self.copy_output_fn
            else:
                output_fn_ = self.gen_copy_output_fn
            output_fn = lambda cell_output, cell_state: output_fn_(
                indices, batch_size, cell_output, cell_state)

        ##TODO to be safe make topn the same as beam size
        return melt.seq2seq.beam_decode(
            input,
            max_words,
            state,
            cell,
            loop_function,
            scope=self.scope,
            beam_size=beam_size,
            done_token=vocabulary.vocab.end_id(),
            output_fn=output_fn,
            length_normalization_factor=length_normalization_factor,
            topn=beam_size)
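
What tile_batch does conceptually, as a rough NumPy equivalent: repeat each batch row beam_size times so decoder state, attention memory and input_text line up row-for-row with the expanded beam:

import numpy as np

def tile_batch(tensor, multiplier):
    # Row b becomes rows b*multiplier .. b*multiplier + multiplier - 1.
    return np.repeat(tensor, multiplier, axis=0)

print(tile_batch(np.array([[1, 2], [3, 4]]), 3))  # shape (6, 2)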
Example #10
0
    def sequence_loss(self,
                      sequence,
                      initial_state=None,
                      attention_states=None,
                      input=None,
                      input_text=None,
                      exact_prob=False,
                      exact_loss=False,
                      emb=None):
        """
    for general seq2seq input is None, sequence will pad <GO>, inital_state is last state from encoder
    for img2text/showandtell input is image_embedding, inital_state is None/zero set
    TODO since exact_porb and exact_loss same value, may remove exact_prob
    NOTICE! assume sequence to be padded by zero and must have one instance full length(no zero!)
    """
        if emb is None:
            emb = self.emb

        is_training = self.is_training
        batch_size = melt.get_batch_size(sequence)

        sequence, sequence_length = melt.pad(sequence,
                                             start_id=self.get_start_id(),
                                             end_id=self.get_end_id())

        #[batch_size, num_steps - 1, emb_dim], remove last col
        inputs = tf.nn.embedding_lookup(emb, sequence[:, :-1])

        if is_training and FLAGS.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, FLAGS.keep_prob)

        #inputs[batch_size, num_steps, emb_dim] input([batch_size, emb_dim] -> [batch_size, 1, emb_dim]) before concat
        if input is not None:
            #used like showandtell, where image_emb is the first input in addition to sequence
            inputs = tf.concat([tf.expand_dims(input, 1), inputs], 1)
        else:
            #common usage: input is None and sequence is the input; note <GO> was already padded by melt.pad above
            sequence_length -= 1
            sequence = sequence[:, 1:]

        if self.is_predict:
            #---only needed for predict, since train input is already dynamic length; NOTICE this improves speed a lot
            num_steps = tf.to_int32(tf.reduce_max(sequence_length))
            sequence = sequence[:, :num_steps]
            inputs = inputs[:, :num_steps, :]

        tf.add_to_collection('sequence', sequence)
        tf.add_to_collection('sequence_length', sequence_length)

        #[batch_size, num_steps]
        targets = sequence

        if attention_states is None:
            cell = self.cell
        else:
            cell = self.prepare_attention(
                attention_states,
                initial_state=initial_state,
                score_as_alignment=self.score_as_alignment)
            initial_state = None
        state = cell.zero_state(
            batch_size, tf.float32) if initial_state is None else initial_state

        if FLAGS.gen_only:
            #gen only mode
            #for attention wrapper we can not use dynamic_rnn if alignments_history=True; TODO pointer_network in application seems ok.. why?
            outputs, state = tf.nn.dynamic_rnn(cell,
                                               inputs,
                                               initial_state=state,
                                               sequence_length=sequence_length,
                                               dtype=tf.float32,
                                               scope=self.scope)

            #--------below is ok but slower than dynamic_rnn, 3.4 -> 3.1 batch/s
            #helper = melt.seq2seq.TrainingHelper(inputs, tf.to_int32(sequence_length))
            ##helper = tf.contrib.seq2seq.TrainingHelper(inputs, tf.to_int32(sequence_length))
            #my_decoder = melt.seq2seq.BasicTrainingDecoder(
            ##my_decoder = tf.contrib.seq2seq.BasicDecoder(
            ##my_decoder = melt.seq2seq.BasicDecoder(
            #      cell=cell,
            #      helper=helper,
            #      initial_state=state)
            ##outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(my_decoder, scope=self.scope)
            #outputs, state, _ = melt.seq2seq.dynamic_decode(my_decoder, scope=self.scope)
            ##outputs = outputs.rnn_output
        else:
            #---copy only or gen copy
            helper = melt.seq2seq.TrainingHelper(inputs,
                                                 tf.to_int32(sequence_length))

            indices = melt.batch_values_to_indices(tf.to_int32(input_text))
            if FLAGS.copy_only:
                output_fn = lambda cell_output, cell_state: self.copy_output_fn(
                    indices, batch_size, cell_output, cell_state)
            else:
                #gen_copy right now, not use switch
                sampled_values = None
                if self.softmax_loss_function is not None:
                    sampled_values = tf.nn.log_uniform_candidate_sampler(
                        true_classes=tf.reshape(targets, [-1, 1]),
                        num_true=1,
                        num_sampled=self.num_sampled,
                        unique=True,
                        range_max=self.vocab_size)
                    #TODO since perf of the sampled version here is ok, not modifying now; but actually, in addition to sampled_values,
                    #sampled_w and sampled_b could also be pre-looked-up embeddings; may improve, though not by much
                output_fn = lambda time, cell_output, cell_state: self.gen_copy_output_train_fn(
                    time, indices, targets, sampled_values, batch_size,
                    cell_output, cell_state)

            my_decoder = melt.seq2seq.BasicTrainingDecoder(
                cell=cell,
                helper=helper,
                initial_state=state,
                vocab_size=self.vocab_size,
                output_fn=output_fn)
            outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(
                my_decoder, scope=self.scope)
            #outputs, state, _ = melt.seq2seq.dynamic_decode(my_decoder, scope=self.scope)

        tf.add_to_collection('outputs', outputs)

        #TODO: hack here add FLAGS.predict_no_sample just for Seq2seqPredictor exact_predict
        softmax_loss_function = self.softmax_loss_function
        if self.is_predict and (exact_prob or exact_loss):
            softmax_loss_function = None

        if not FLAGS.gen_only:
            logits = outputs
            softmax_loss_function = None
        elif softmax_loss_function is not None:
            logits = outputs
        else:
            #[batch_size, num_steps, num_units] * [num_units, vocab_size]
            # -> logits [batch_size, num_steps, vocab_size] (if use exact_predict_loss)
            #or [batch_size * num_steps, vocab_size] by default flatten=True
            keep_dims = exact_prob or exact_loss
            logits = melt.batch_matmul_embedding(
                outputs, self.w, keep_dims=keep_dims) + self.v
            if not keep_dims:
                targets = tf.reshape(targets, [-1])

        tf.add_to_collection('logits', logits)

        #if input_text is not None:
        #  logits = outputs

        mask = tf.cast(tf.sign(targets), dtype=tf.float32)

        if FLAGS.gen_copy_switch:
            #TODO why need more gpu mem ? ...  do not save logits ? just calc loss in output_fn ?
            #batch size 256
            #File "/home/gezi/mine/hasky/util/melt/seq2seq/loss.py", line 154, in body
            #step_logits = logits[:, i, :]
            #ResourceExhaustedError (see above for traceback): OOM when allocating tensor with shape[256,21,33470]
            num_steps = tf.shape(targets)[1]

            loss = melt.seq2seq.exact_predict_loss(logits,
                                                   targets,
                                                   mask,
                                                   num_steps,
                                                   need_softmax=False,
                                                   need_average=True,
                                                   batch_size=batch_size)

            # loss = melt.seq2seq.sequence_loss_by_example(
            #     logits,
            #     targets,
            #     weights=mask)
        elif self.is_predict and exact_prob:
            #generate real prob for sequence
            #for 10w vocab textsum seq2seq 20 -> 4 about
            loss = melt.seq2seq.exact_predict_loss(logits,
                                                   targets,
                                                   mask,
                                                   num_steps,
                                                   batch_size=batch_size)
        elif self.is_predict and exact_loss:
            #force non-sampled softmax loss; the diff with exact_prob is that here we just use the cross-entropy error as the result, not the real prob of the sequence
            #NOTICE this uses a bit less time, 55 vs 57 (prob), and gives the same result as exact prob and exact score
            #but a 256-vocab sample will use only about 10ms
            loss = melt.seq2seq.sequence_loss_by_example(logits,
                                                         targets,
                                                         weights=mask)
        else:
            #loss [batch_size,]
            loss = melt.seq2seq.sequence_loss_by_example(
                logits,
                targets,
                weights=mask,
                softmax_loss_function=softmax_loss_function)

        #mainly for compat with [batch_size, num_losses]
        loss = tf.reshape(loss, [-1, 1])

        if self.is_predict:
            loss = self.normalize_length(loss, sequence_length, exact_prob)
            #loss = tf.squeeze(loss)  TODO: later will uncomment this with all models rerun
        return loss
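
A tiny illustration of the mask = tf.cast(tf.sign(targets), ...) trick used above, which relies on the NOTICE in the docstring that sequences are zero padded:

import numpy as np

targets = np.array([[5, 9, 2, 0, 0],
                    [7, 3, 0, 0, 0]])
mask = np.sign(targets).astype(np.float32)  # 1.0 on real tokens, 0.0 on zero padding
print(mask)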
Example #11
0
    def generate_sequence(self,
                          input,
                          max_words,
                          initial_state=None,
                          attention_states=None,
                          convert_unk=True,
                          input_text=None,
                          Helper=None,
                          emb=None):
        """
    this one is using greedy search method
    for beam search using generate_sequence_by_beam_search with addditional params like beam_size
    """
        if emb is None:
            emb = self.emb

        batch_size = melt.get_batch_size(input)
        if attention_states is None:
            cell = self.cell
        else:
            cell = self.prepare_attention(
                attention_states,
                initial_state=initial_state,
                score_as_alignment=self.score_as_alignment)
            initial_state = None
        state = cell.zero_state(
            batch_size, tf.float32) if initial_state is None else initial_state

        need_logprobs = FLAGS.greedy_decode_with_logprobs
        if Helper is None:
            if not need_logprobs:
                helper = melt.seq2seq.GreedyEmbeddingHelper(
                    embedding=emb, first_input=input, end_token=self.end_id)
            else:
                helper = melt.seq2seq.LogProbsGreedyEmbeddingHelper(
                    embedding=emb,
                    first_input=input,
                    end_token=self.end_id,
                    need_softmax=self.need_softmax)
        else:
            helper = melt.seq2seq.MultinomialEmbeddingHelper(
                embedding=emb,
                first_input=input,
                end_token=self.end_id,
                need_softmax=self.need_softmax)

        if FLAGS.gen_only:
            output_fn = self.output_fn
        else:
            indices = melt.batch_values_to_indices(tf.to_int32(input_text))
            if FLAGS.copy_only:
                output_fn_ = self.copy_output_fn
            else:
                output_fn_ = self.gen_copy_output_fn
            output_fn = lambda cell_output, cell_state: output_fn_(
                indices, batch_size, cell_output, cell_state)

        Decoder = melt.seq2seq.BasicDecoder if not need_logprobs else melt.seq2seq.LogProbsDecoder
        my_decoder = Decoder(cell=cell,
                             helper=helper,
                             initial_state=state,
                             vocab_size=self.vocab_size,
                             output_fn=output_fn)

        outputs, final_state, sequence_length = melt.seq2seq.dynamic_decode(
            my_decoder,
            maximum_iterations=max_words,
            #MUST be set to True, otherwise tokens past the done/end token will not be zeroed out and will be summed up
            impute_finished=True,
            scope=self.scope)
        sequence = outputs.sample_id
        if not hasattr(final_state, 'log_probs'):
            score = tf.zeros([
                batch_size,
            ])
        else:
            score = self.normalize_length(final_state.log_probs,
                                          sequence_length,
                                          reshape=False)
            ##below can be verified to be the same
            # num_steps = tf.to_int32(tf.reduce_max(sequence_length))
            # score2 = -melt.seq2seq.exact_predict_loss(outputs.rnn_output, sequence, tf.to_float(tf.sign(sequence)),
            #                                       num_steps, need_softmax=True, average_across_timesteps=False)
            # score2 = self.normalize_length(score2, sequence_length, reshape=False)
            # score -= score2

            #score = tf.exp(score)
            #score = tf.concat([tf.expand_dims(score, 1), outputs.log_probs], 1)
            if FLAGS.predict_use_prob:
                score = tf.exp(score)
            tf.add_to_collection('greedy_log_probs_list', outputs.log_probs)

        #------like beam search, return (sequence, score)
        return sequence, score
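
A hedged sketch of the length normalization applied to the accumulated log probability, assuming normalize_length divides by the length raised to a normalization factor (the exact melt implementation may differ):

import numpy as np

def normalize_length(log_probs, lengths, length_normalization_factor=1.0):
    # Divide the accumulated log-probability by length**factor so long sequences are not unfairly penalized.
    lengths = np.maximum(np.asarray(lengths, dtype=np.float32), 1.0)
    return log_probs / lengths ** length_normalization_factor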
Example #12
0
    def sequence_loss(self,
                      sequence,
                      initial_state=None,
                      attention_states=None,
                      input=None,
                      input_text=None,
                      exact_prob=False,
                      exact_loss=False,
                      emb=None):
        """
    for general seq2seq input is None, sequence will pad <GO>, inital_state is last state from encoder
    for showandtell input is image_embedding, inital_state is None/zero set, if use im2txt mode set image_as_init_state=True will do as above, need to PAD <GO> !
    TODO since exact_porb and exact_loss same value, may remove exact_prob
    NOTICE! assume sequence to be padded by zero and must have one instance full length(no zero!)
    """
        if emb is None:
            emb = self.emb

        is_training = self.is_training
        batch_size = melt.get_batch_size(sequence)

        sequence, sequence_length = melt.pad(sequence,
                                             start_id=self.get_start_id(),
                                             end_id=self.get_end_id())

        #[batch_size, num_steps - 1, emb_dim], remove last col
        inputs = tf.nn.embedding_lookup(emb, sequence[:, :-1])

        if is_training and FLAGS.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, FLAGS.keep_prob)

        #inputs[batch_size, num_steps, emb_dim] input([batch_size, emb_dim] -> [batch_size, 1, emb_dim]) before concat
        if input is not None:
            #used like showandtell, where image_emb is the first input in addition to sequence
            inputs = tf.concat([tf.expand_dims(input, 1), inputs], 1)
        else:
            #common usage: input is None and sequence is the input; note <GO> was already padded by melt.pad above
            sequence_length -= 1
            sequence = sequence[:, 1:]

        if self.is_predict:
            #---only needed for predict, since train input is already dynamic length; NOTICE this improves speed a lot
            num_steps = tf.to_int32(tf.reduce_max(sequence_length))
            sequence = sequence[:, :num_steps]
            inputs = inputs[:, :num_steps, :]

        tf.add_to_collection('sequence', sequence)
        tf.add_to_collection('sequence_length', sequence_length)

        #[batch_size, num_steps]
        targets = sequence

        if attention_states is None:
            cell = self.cell
        else:
            cell = self.prepare_attention(
                attention_states,
                initial_state=initial_state,
                score_as_alignment=self.score_as_alignment)
            initial_state = None
        state = cell.zero_state(
            batch_size, tf.float32) if initial_state is None else initial_state

        #TODO: hack here add FLAGS.predict_no_sample just for Seq2seqPredictor exact_predict
        softmax_loss_function = self.softmax_loss_function
        if self.is_predict and (exact_prob or exact_loss):
            softmax_loss_function = None

        scheduled_sampling_probability = FLAGS.scheduled_sampling_probability if self.is_training else 0.
        if FLAGS.gen_only:
            #gen only mode
            #for attention wrapper we can not use dynamic_rnn if alignments_history=True; TODO pointer_network in application seems ok.. why?
            if scheduled_sampling_probability > 0.:
                helper = melt.seq2seq.ScheduledEmbeddingTrainingHelper(
                    inputs, tf.to_int32(sequence_length), emb,
                    tf.constant(FLAGS.scheduled_sampling_probability))
                #helper = tf.contrib.seq2seq.TrainingHelper(inputs, tf.to_int32(sequence_length))
                my_decoder = melt.seq2seq.BasicDecoder(
                    #my_decoder = tf.contrib.seq2seq.BasicDecoder(
                    #my_decoder = melt.seq2seq.BasicDecoder(
                    cell=cell,
                    helper=helper,
                    initial_state=state)
                outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(
                    my_decoder, scope=self.scope)
                #outputs, state, _ = melt.seq2seq.dynamic_decode(my_decoder, scope=self.scope)
                outputs = outputs.rnn_output
            else:
                outputs, state = tf.nn.dynamic_rnn(
                    cell,
                    inputs,
                    initial_state=state,
                    sequence_length=sequence_length,
                    dtype=tf.float32,
                    scope=self.scope)

            #--------below is ok but slower than dynamic_rnn, 3.4 -> 3.1 batch/s
            #helper = melt.seq2seq.TrainingHelper(inputs, tf.to_int32(sequence_length))
            ##helper = tf.contrib.seq2seq.TrainingHelper(inputs, tf.to_int32(sequence_length))
            #my_decoder = melt.seq2seq.BasicTrainingDecoder(
            ##my_decoder = tf.contrib.seq2seq.BasicDecoder(
            ##my_decoder = melt.seq2seq.BasicDecoder(
            #      cell=cell,
            #      helper=helper,
            #      initial_state=state)
            ##outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(my_decoder, scope=self.scope)
            #outputs, state, _ = melt.seq2seq.dynamic_decode(my_decoder, scope=self.scope)
            ##outputs = outputs.rnn_output
        else:
            #---copy only or gen copy
            if scheduled_sampling_probability > 0.:
                #not tested yet TODO
                helper = melt.seq2seq.ScheduledEmbeddingTrainingHelper(
                    inputs, tf.to_int32(sequence_length), emb,
                    tf.constant(FLAGS.scheduled_sampling_probability))
                Decoder_ = melt.seq2seq.BasicDecoder
            else:
                #as before
                helper = melt.seq2seq.TrainingHelper(
                    inputs, tf.to_int32(sequence_length))
                Decoder_ = melt.seq2seq.BasicTrainingDecoder

            indices = melt.batch_values_to_indices(tf.to_int32(input_text))
            if FLAGS.copy_only:
                output_fn = lambda cell_output, cell_state: self.copy_output_fn(
                    indices, batch_size, cell_output, cell_state)
            else:
                #gen_copy right now; not using switch? or gen_copy and switch?
                sampled_values = None
                #TODO CHECK: is this ok? why are train and predict not equal while score/exact score are the same? FIXME
                #need to debug first why score and exact score are the same; score should be the same as train! TODO
                #sh ./inference/infrence-score.sh to reproduce
                #for now just set num_sampled = 0 to be safe; training here may also be incorrect FIXME
                if softmax_loss_function is not None:
                    sampled_values = tf.nn.log_uniform_candidate_sampler(
                        true_classes=tf.reshape(targets, [-1, 1]),
                        num_true=1,
                        num_sampled=self.num_sampled,
                        unique=True,
                        range_max=self.vocab_size)
                    #TODO since perf of the sampled version here is ok, not modifying now; but actually, in addition to sampled_values,
                    #sampled_w and sampled_b could also be pre-looked-up embeddings; may improve, though not by much
                output_fn = lambda time, cell_output, cell_state: self.gen_copy_output_train_fn(
                    time, indices, targets, sampled_values, batch_size,
                    cell_output, cell_state)

            my_decoder = Decoder_(cell=cell,
                                  helper=helper,
                                  initial_state=state,
                                  vocab_size=self.vocab_size,
                                  output_fn=output_fn)
            outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(
                my_decoder, scope=self.scope)
            #outputs, state, _ = melt.seq2seq.dynamic_decode(my_decoder, scope=self.scope)
            if hasattr(outputs, 'rnn_output'):
                outputs = outputs.rnn_output

        tf.add_to_collection('outputs', outputs)

        if not FLAGS.gen_only:
            logits = outputs
            softmax_loss_function = None
        elif softmax_loss_function is not None:
            logits = outputs
        else:
            #--softmax_loss_function is None means num_sample = 0 or exact_loss or exact_prob
            #[batch_size, num_steps, num_units] * [num_units, vocab_size]
            # -> logits [batch_size, num_steps, vocab_size] (if use exact_predict_loss)
            #or [batch_size * num_steps, vocab_size] by default flatten=True
            #this is fine for train [batch_size * num_steps] but not good for eval, since we want
            #the score of each instance; also not good for predict
            #--------only training mode does not keep dims, but this is dangerous, since classes that call rnn_decoder
            #need to manually set rnn_decoder.is_training=False! TODO otherwise incorrect scores will show in eval mode,
            #but this does not affect the final model!
            keep_dims = exact_prob or exact_loss or (not self.is_training)
            logits = melt.batch_matmul_embedding(
                outputs, self.w, keep_dims=keep_dims) + self.v
            if not keep_dims:
                targets = tf.reshape(targets, [-1])

        tf.add_to_collection('logits', logits)

        mask = tf.cast(tf.sign(targets), dtype=tf.float32)

        if FLAGS.gen_copy_switch and FLAGS.switch_after_softmax:
            #TODO why need more gpu mem ? ...  do not save logits ? just calc loss in output_fn ?
            #batch size 256
            #File "/home/gezi/mine/hasky/util/melt/seq2seq/loss.py", line 154, in body
            #step_logits = logits[:, i, :]
            #ResourceExhaustedError (see above for traceback): OOM when allocating tensor with shape[256,21,33470]
            num_steps = tf.shape(targets)[1]

            loss = melt.seq2seq.exact_predict_loss(
                logits,
                targets,
                mask,
                num_steps,
                need_softmax=False,
                average_across_timesteps=not self.is_predict,
                batch_size=batch_size)
        elif self.is_predict and exact_prob:
            #generate real prob for sequence
            #for 10w vocab textsum seq2seq 20 -> 4 about
            loss = melt.seq2seq.exact_predict_loss(
                logits,
                targets,
                mask,
                num_steps,
                batch_size=batch_size,
                average_across_timesteps=False)
        elif self.is_predict and exact_loss:
            #force non-sampled softmax loss; the diff with exact_prob is that here we just use the cross-entropy error as the result, not the real prob of the sequence
            #NOTICE this uses a bit less time, 55 vs 57 (prob), and gives the same result as exact prob and exact score
            #but a 256-vocab sample will use only about 10ms
            loss = melt.seq2seq.sequence_loss_by_example(
                logits, targets, weights=mask, average_across_timesteps=False)
        else:
            #loss [batch_size,]
            loss = melt.seq2seq.sequence_loss_by_example(
                logits,
                targets,
                weights=mask,
                average_across_timesteps=not self.is_predict,  #train must average, otherwise long sentences get a big loss..
                softmax_loss_function=softmax_loss_function)

        #mainly for compat with [batch_size, num_losses]; here it may be [batch_size * num_steps,] if is_training and not exact loss/prob
        loss = tf.reshape(loss, [-1, 1])

        self.ori_loss = loss
        if self.is_predict:
            #note: use avg_loss so as not to change the loss pointer; avg_loss is the same as average_across_timesteps=True with length_normalization_factor=1.0
            avg_loss = self.normalize_length(loss, sequence_length)
            return avg_loss

        #if not is_predict, loss is averaged per time step; otherwise it is not, but avg_loss will average it
        return loss
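
A one-function sketch of what ScheduledEmbeddingTrainingHelper does at each step (before the embedding lookup): with probability sampling_probability, feed back the model's own sampled token instead of the ground-truth token:

import numpy as np

def scheduled_sampling_input(ground_truth_id, model_sample_id, sampling_probability, rng=np.random):
    # With probability sampling_probability use the model's own sample, else keep teacher forcing.
    return model_sample_id if rng.random_sample() < sampling_probability else ground_truth_id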