Example #1
    def sequence_loss(self,
                      sequence,
                      initial_state=None,
                      attention_states=None,
                      input=None,
                      input_text=None,
                      exact_prob=False,
                      exact_loss=False,
                      emb=None):
        """
    For general seq2seq, input is None, sequence will be padded with <GO>, and initial_state is the last state from the encoder.
    For img2text/showandtell, input is the image_embedding and initial_state is None (zero state).
    TODO: since exact_prob and exact_loss give the same value, exact_prob may be removed.
    NOTICE! sequence is assumed to be zero-padded and at least one instance must have full length (no zeros)!
    """
        if emb is None:
            emb = self.emb

        is_training = self.is_training
        batch_size = tf.shape(sequence)[0]

        sequence, sequence_length = melt.pad(sequence,
                                             start_id=self.get_start_id(),
                                             end_id=self.get_end_id())

        #TODO different init state as shown in ptb_word_lm
        state = self.cell.zero_state(
            batch_size, tf.float32) if initial_state is None else initial_state

        #[batch_size, num_steps - 1, emb_dim], remove last col
        inputs = tf.nn.embedding_lookup(emb, sequence[:, :-1])

        if is_training and FLAGS.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, FLAGS.keep_prob)

        #inputs [batch_size, num_steps, emb_dim]; input ([batch_size, emb_dim] -> [batch_size, 1, emb_dim]) is expanded before concat
        if input is not None:
            #used like showandtell, where image_emb is the first input in addition to the sequence
            inputs = tf.concat([tf.expand_dims(input, 1), inputs], 1)
        else:
            #common usage: input is None and the sequence itself is the input; note <GO> was already padded above by melt.pad
            sequence_length -= 1
            sequence = sequence[:, 1:]

        if self.is_predict:
            #---only needed when predicting, since train input is already dynamic length; NOTICE this improves speed a lot
            num_steps = tf.to_int32(tf.reduce_max(sequence_length))
            sequence = sequence[:, :num_steps]
            inputs = inputs[:, :num_steps, :]

        tf.add_to_collection('sequence', sequence)
        tf.add_to_collection('sequence_length', sequence_length)

        if attention_states is None:
            outputs, state = tf.nn.dynamic_rnn(self.cell,
                                               inputs,
                                               initial_state=state,
                                               sequence_length=sequence_length,
                                               scope=self.scope)
            self.final_state = state
        else:
            attention_keys, attention_values, attention_score_fn, attention_construct_fn = \
              self.prepare_attention(attention_states)
            decoder_fn_train = melt.seq2seq.attention_decoder_fn_train(
                encoder_state=state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn)
            decoder_outputs_train, decoder_state_train, _ = \
                          melt.seq2seq.dynamic_rnn_decoder(
                              cell=self.cell,
                              decoder_fn=decoder_fn_train,
                              inputs=inputs,
                              sequence_length=tf.to_int32(sequence_length),
                              scope=self.scope)
            outputs = decoder_outputs_train

            self.final_state = decoder_state_train

        tf.add_to_collection('outputs', outputs)

        #[batch_size, num_steps]
        targets = sequence

        if FLAGS.copy_only:
            #TODO: not working yet!
            attention_scores = tf.get_collection('attention_scores')[-1]
            indices = melt.batch_values_to_indices(input_text)
            #logits = ;
        else:
            #TODO: hack here add FLAGS.predict_no_sample just for Seq2seqPredictor exact_predict
            softmax_loss_function = self.softmax_loss_function
            if self.is_predict and (exact_prob or exact_loss):
                softmax_loss_function = None

            if softmax_loss_function is None:
                #[batch_size, num_steps, num_units] * [num_units, vocab_size]
                # -> logits [batch_size, num_steps, vocab_size] (if use exact_predict_loss)
                #or [batch_size * num_steps, vocab_size] by default flatten=True
                keep_dims = exact_prob or exact_loss
                logits = melt.batch_matmul_embedding(
                    outputs, self.w, keep_dims=keep_dims) + self.v
                if not keep_dims:
                    targets = tf.reshape(targets, [-1])
            else:
                logits = outputs

            mask = tf.cast(tf.sign(targets), dtype=tf.float32)

            if self.is_predict and exact_prob:
                #generate real prob for sequence
                #for a 100k (10w) vocab textsum seq2seq, about 20 -> 4
                loss = melt.seq2seq.exact_predict_loss(logits, targets, mask,
                                                       num_steps, batch_size)
            elif self.is_predict and exact_loss:
                #force non-sampled softmax loss; the difference from exact_prob is that here we just use the cross entropy error as the result, not the real probability of the sequence
                #NOTICE this uses a bit less time, 55 vs 57 (prob), and gives the same result as exact prob and exact score
                #but a 256-vocab sample will use only about 10ms
                #TODO check more with softmax loss and sampled softmax loss, check length normalization
                loss = melt.seq2seq.sequence_loss_by_example(logits,
                                                             targets,
                                                             weights=mask)
            else:
                #loss [batch_size,]
                loss = melt.seq2seq.sequence_loss_by_example(
                    logits,
                    targets,
                    weights=mask,
                    softmax_loss_function=softmax_loss_function)

        #mainly for compat with [batch_size, num_losses]
        loss = tf.reshape(loss, [-1, 1])

        if self.is_predict:
            loss = self.normalize_length(loss, sequence_length, exact_prob)
            #loss = tf.squeeze(loss)  TODO: uncomment later once all models are rerun
        return loss
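
A minimal call sketch for the variant above; all names here (decoder, target_ids, encoder_final_state, encoder_outputs, caption_ids, image_embedding) are hypothetical and only illustrate the two usages described in the docstring, not a confirmed API beyond sequence_loss itself.

    # seq2seq usage: zero-padded target ids, encoder final state as initial_state,
    # encoder outputs as attention_states (attention is optional).
    loss = decoder.sequence_loss(target_ids,
                                 initial_state=encoder_final_state,
                                 attention_states=encoder_outputs)

    # showandtell-style usage: the image embedding is fed as the first decoder input,
    # so initial_state stays None (zero state).
    caption_loss = decoder.sequence_loss(caption_ids, input=image_embedding)
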
Example #2
    def sequence_loss(self,
                      sequence,
                      initial_state=None,
                      attention_states=None,
                      input=None,
                      input_text=None,
                      exact_prob=False,
                      exact_loss=False,
                      emb=None):
        """
    For general seq2seq, input is None, sequence will be padded with <GO>, and initial_state is the last state from the encoder.
    For img2text/showandtell, input is the image_embedding and initial_state is None (zero state).
    TODO: since exact_prob and exact_loss give the same value, exact_prob may be removed.
    NOTICE! sequence is assumed to be zero-padded and at least one instance must have full length (no zeros)!
    """
        if emb is None:
            emb = self.emb

        is_training = self.is_training
        batch_size = tf.shape(sequence)[0]

        sequence, sequence_length = melt.pad(sequence,
                                             start_id=self.get_start_id(),
                                             end_id=self.get_end_id())

        #[batch_size, num_steps - 1, emb_dim], remove last col
        inputs = tf.nn.embedding_lookup(emb, sequence[:, :-1])

        if is_training and FLAGS.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, FLAGS.keep_prob)

        #inputs [batch_size, num_steps, emb_dim]; input ([batch_size, emb_dim] -> [batch_size, 1, emb_dim]) is expanded before concat
        if input is not None:
            #used like showandtell, where image_emb is the first input in addition to the sequence
            inputs = tf.concat([tf.expand_dims(input, 1), inputs], 1)
        else:
            #common usage: input is None and the sequence itself is the input; note <GO> was already padded above by melt.pad
            sequence_length -= 1
            sequence = sequence[:, 1:]

        if self.is_predict:
            #---only needed when predicting, since train input is already dynamic length; NOTICE this improves speed a lot
            num_steps = tf.to_int32(tf.reduce_max(sequence_length))
            sequence = sequence[:, :num_steps]
            inputs = inputs[:, :num_steps, :]

        tf.add_to_collection('sequence', sequence)
        tf.add_to_collection('sequence_length', sequence_length)

        if attention_states is None:
            cell = self.cell
        else:
            cell = self.prepare_attention(attention_states,
                                          initial_state=initial_state)
            #initial_state = None
            initial_state = cell.zero_state(batch_size, tf.float32)
        state = cell.zero_state(
            batch_size, tf.float32) if initial_state is None else initial_state

        #if attention_states is None:
        #-----TODO using attention_wrapper works now with dynamic_rnn but is still slower than the old attention method...
        outputs, state = tf.nn.dynamic_rnn(cell,
                                           inputs,
                                           initial_state=state,
                                           sequence_length=sequence_length,
                                           dtype=tf.float32,
                                           scope=self.scope)
        #else:
        #  #---below is also ok but slower: above 16+, below only 13-14 batch/s, may be due to sample id
        #  #TODO: can we make the code below as fast as tf.nn.dynamic_rnn? if sample id is not needed, remove it?
        #  #FIXME... AttentionWrapper is only 1/2 the speed compared to the old function-based attention, why?
        #  #helper = tf.contrib.seq2seq.TrainingHelper(inputs, tf.to_int32(sequence_length))
        #  helper = melt.seq2seq.TrainingHelper(inputs, tf.to_int32(sequence_length))
        #  #my_decoder = tf.contrib.seq2seq.BasicDecoder(
        #  my_decoder = melt.seq2seq.BasicTrainingDecoder(
        #      cell=cell,
        #      helper=helper,
        #      initial_state=state)
        #  outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(my_decoder, scope=self.scope)
        #  #outputs = outputs.rnn_output

        self.final_state = state

        tf.add_to_collection('outputs', outputs)

        #[batch_size, num_steps]
        targets = sequence

        if FLAGS.copy_only:
            #TODO: not working yet!
            attention_scores = tf.get_collection('attention_scores')[-1]
            indices = melt.batch_values_to_indices(input_text)
            #logits = ;
        else:
            #TODO: hack here add FLAGS.predict_no_sample just for Seq2seqPredictor exact_predict
            softmax_loss_function = self.softmax_loss_function
            if self.is_predict and (exact_prob or exact_loss):
                softmax_loss_function = None

            if softmax_loss_function is None:
                #[batch_size, num_steps, num_units] * [num_units, vocab_size]
                # -> logits [batch_size, num_steps, vocab_size] (if use exact_predict_loss)
                #or [batch_size * num_steps, vocab_size] by default flatten=True
                keep_dims = exact_prob or exact_loss
                logits = melt.batch_matmul_embedding(
                    outputs, self.w, keep_dims=keep_dims) + self.v
                if not keep_dims:
                    targets = tf.reshape(targets, [-1])
            else:
                logits = outputs

            mask = tf.cast(tf.sign(targets), dtype=tf.float32)

            if self.is_predict and exact_prob:
                #generate real prob for sequence
                #for a 100k (10w) vocab textsum seq2seq, about 20 -> 4
                loss = melt.seq2seq.exact_predict_loss(logits, targets, mask,
                                                       num_steps, batch_size)
            elif self.is_predict and exact_loss:
                #force non-sampled softmax loss; the difference from exact_prob is that here we just use the cross entropy error as the result, not the real probability of the sequence
                #NOTICE this uses a bit less time, 55 vs 57 (prob), and gives the same result as exact prob and exact score
                #but a 256-vocab sample will use only about 10ms
                #TODO check more with softmax loss and sampled softmax loss, check length normalization
                loss = melt.seq2seq.sequence_loss_by_example(logits,
                                                             targets,
                                                             weights=mask)
            else:
                #loss [batch_size,]
                loss = melt.seq2seq.sequence_loss_by_example(
                    logits,
                    targets,
                    weights=mask,
                    softmax_loss_function=softmax_loss_function)

        #mainly for compat with [batch_size, num_losses]
        loss = tf.reshape(loss, [-1, 1])

        if self.is_predict:
            loss = self.normalize_length(loss, sequence_length, exact_prob)
            #loss = tf.squeeze(loss)  TODO: uncomment later once all models are rerun
        return loss
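
This variant replaces the function-based attention of Example #1 with an attention-wrapped cell returned by prepare_attention, so the wrapped cell's zero_state becomes the decoder state. Below is a hedged sketch of the equivalent stock tf.contrib.seq2seq pattern; it assumes prepare_attention wraps the cell roughly like this, and num_units is an assumed hyperparameter.

    # Rough equivalent using the stock TF 1.x contrib API; melt's prepare_attention
    # is assumed to build something like an AttentionWrapper internally.
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units, attention_states)
    cell = tf.contrib.seq2seq.AttentionWrapper(self.cell, attention_mechanism)
    state = cell.zero_state(batch_size, tf.float32)
    if initial_state is not None:
        # seed the wrapped state with the encoder's final state
        state = state.clone(cell_state=initial_state)
    outputs, state = tf.nn.dynamic_rnn(cell, inputs, initial_state=state,
                                       sequence_length=sequence_length,
                                       dtype=tf.float32, scope=self.scope)
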
Example #3
    def sequence_loss(self,
                      sequence,
                      initial_state=None,
                      attention_states=None,
                      input=None,
                      input_text=None,
                      exact_prob=False,
                      exact_loss=False,
                      emb=None):
        """
    For general seq2seq, input is None, sequence will be padded with <GO>, and initial_state is the last state from the encoder.
    For showandtell, input is the image_embedding and initial_state is None (zero state); if using im2txt mode, set image_as_init_state=True to behave as above, which requires padding <GO>!
    TODO: since exact_prob and exact_loss give the same value, exact_prob may be removed.
    NOTICE! sequence is assumed to be zero-padded and at least one instance must have full length (no zeros)!
    """
        if emb is None:
            emb = self.emb

        is_training = self.is_training
        batch_size = melt.get_batch_size(sequence)

        sequence, sequence_length = melt.pad(sequence,
                                             start_id=self.get_start_id(),
                                             end_id=self.get_end_id())

        #[batch_size, num_steps - 1, emb_dim], remove last col
        inputs = tf.nn.embedding_lookup(emb, sequence[:, :-1])

        if is_training and FLAGS.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, FLAGS.keep_prob)

        #inputs [batch_size, num_steps, emb_dim]; input ([batch_size, emb_dim] -> [batch_size, 1, emb_dim]) is expanded before concat
        if input is not None:
            #used like showandtell, where image_emb is the first input in addition to the sequence
            inputs = tf.concat([tf.expand_dims(input, 1), inputs], 1)
        else:
            #common usage: input is None and the sequence itself is the input; note <GO> was already padded above by melt.pad
            sequence_length -= 1
            sequence = sequence[:, 1:]

        if self.is_predict:
            #---only needed when predicting, since train input is already dynamic length; NOTICE this improves speed a lot
            num_steps = tf.to_int32(tf.reduce_max(sequence_length))
            sequence = sequence[:, :num_steps]
            inputs = inputs[:, :num_steps, :]

        tf.add_to_collection('sequence', sequence)
        tf.add_to_collection('sequence_length', sequence_length)

        #[batch_size, num_steps]
        targets = sequence

        if attention_states is None:
            cell = self.cell
        else:
            cell = self.prepare_attention(
                attention_states,
                initial_state=initial_state,
                score_as_alignment=self.score_as_alignment)
            initial_state = None
        state = cell.zero_state(
            batch_size, tf.float32) if initial_state is None else initial_state

        #TODO: hack here add FLAGS.predict_no_sample just for Seq2seqPredictor exact_predict
        softmax_loss_function = self.softmax_loss_function
        if self.is_predict and (exact_prob or exact_loss):
            softmax_loss_function = None

        scheduled_sampling_probability = FLAGS.scheduled_sampling_probability if self.is_training else 0.
        if FLAGS.gen_only:
            #gen only mode
            #for the attention wrapper, dynamic_rnn cannot be used if alignments_history=True; TODO see pointer_network in application, seems ok.. why
            if scheduled_sampling_probability > 0.:
                helper = melt.seq2seq.ScheduledEmbeddingTrainingHelper(
                    inputs, tf.to_int32(sequence_length), emb,
                    tf.constant(FLAGS.scheduled_sampling_probability))
                #helper = tf.contrib.seq2seq.TrainingHelper(inputs, tf.to_int32(sequence_length))
                my_decoder = melt.seq2seq.BasicDecoder(
                    #my_decoder = tf.contrib.seq2seq.BasicDecoder(
                    #my_decoder = melt.seq2seq.BasicDecoder(
                    cell=cell,
                    helper=helper,
                    initial_state=state)
                outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(
                    my_decoder, scope=self.scope)
                #outputs, state, _ = melt.seq2seq.dynamic_decode(my_decoder, scope=self.scope)
                outputs = outputs.rnn_output
            else:
                outputs, state = tf.nn.dynamic_rnn(
                    cell,
                    inputs,
                    initial_state=state,
                    sequence_length=sequence_length,
                    dtype=tf.float32,
                    scope=self.scope)

            #--------below is ok but slower than dynamic_rnn, 3.4 -> 3.1 batch/s
            #helper = melt.seq2seq.TrainingHelper(inputs, tf.to_int32(sequence_length))
            ##helper = tf.contrib.seq2seq.TrainingHelper(inputs, tf.to_int32(sequence_length))
            #my_decoder = melt.seq2seq.BasicTrainingDecoder(
            ##my_decoder = tf.contrib.seq2seq.BasicDecoder(
            ##my_decoder = melt.seq2seq.BasicDecoder(
            #      cell=cell,
            #      helper=helper,
            #      initial_state=state)
            ##outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(my_decoder, scope=self.scope)
            #outputs, state, _ = melt.seq2seq.dynamic_decode(my_decoder, scope=self.scope)
            ##outputs = outputs.rnn_output
        else:
            #---copy only or gen copy
            if scheduled_sampling_probability > 0.:
                #not tested yet TODO
                helper = melt.seq2seq.ScheduledEmbeddingTrainingHelper(
                    inputs, tf.to_int32(sequence_length), emb,
                    tf.constant(FLAGS.scheduled_sampling_probability))
                Decoder_ = melt.seq2seq.BasicDecoder
            else:
                #as before
                helper = melt.seq2seq.TrainingHelper(
                    inputs, tf.to_int32(sequence_length))
                Decoder_ = melt.seq2seq.BasicTrainingDecoder

            indices = melt.batch_values_to_indices(tf.to_int32(input_text))
            if FLAGS.copy_only:
                output_fn = lambda cell_output, cell_state: self.copy_output_fn(
                    indices, batch_size, cell_output, cell_state)
            else:
                #gen_copy for now; not using switch? gen_copy and switch?
                sampled_values = None
                #TODO CHECK is this ok? why are train and predict not equal while score/exact score are the same? FIXME
                #first need to debug why score and exact score are the same; score should be the same as train! TODO
                #sh ./inference/infrence-score.sh to reproduce
                #for now just set num_sampled = 0 to be safe; train here may also be incorrect FIXME
                if softmax_loss_function is not None:
                    sampled_values = tf.nn.log_uniform_candidate_sampler(
                        true_classes=tf.reshape(targets, [-1, 1]),
                        num_true=1,
                        num_sampled=self.num_sampled,
                        unique=True,
                        range_max=self.vocab_size)
                    #TODO since the perf of the sampled version here is ok, not modifying now; but in addition to sampled_values,
                    #sampled_w, sampled_b could also be pre embedding lookup, may not improve much
                output_fn = lambda time, cell_output, cell_state: self.gen_copy_output_train_fn(
                    time, indices, targets, sampled_values, batch_size,
                    cell_output, cell_state)

            my_decoder = Decoder_(cell=cell,
                                  helper=helper,
                                  initial_state=state,
                                  vocab_size=self.vocab_size,
                                  output_fn=output_fn)
            outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(
                my_decoder, scope=self.scope)
            #outputs, state, _ = melt.seq2seq.dynamic_decode(my_decoder, scope=self.scope)
            if hasattr(outputs, 'rnn_output'):
                outputs = outputs.rnn_output

        tf.add_to_collection('outputs', outputs)

        if not FLAGS.gen_only:
            logits = outputs
            softmax_loss_function = None
        elif softmax_loss_function is not None:
            logits = outputs
        else:
            #--softmax_loss_function is None means num_sample = 0 or exact_loss or exact_prob
            #[batch_size, num_steps, num_units] * [num_units, vocab_size]
            # -> logits [batch_size, num_steps, vocab_size] (if use exact_predict_loss)
            #or [batch_size * num_steps, vocab_size] by default flatten=True
            #this will be fine for train [batch_size * num_steps] but not good for eval, since we want
            #to get the score of each instance; it is also not good for predict
            #--------only training mode does not keep dims, but this is dangerous, since classes calling rnn_decoder
            #need to manually set rnn_decoder.is_training=False!  TODO otherwise incorrect scores will be shown in eval mode
            #but the final model is not affected!
            keep_dims = exact_prob or exact_loss or (not self.is_training)
            logits = melt.batch_matmul_embedding(
                outputs, self.w, keep_dims=keep_dims) + self.v
            if not keep_dims:
                targets = tf.reshape(targets, [-1])

        tf.add_to_collection('logits', logits)

        mask = tf.cast(tf.sign(targets), dtype=tf.float32)

        if FLAGS.gen_copy_switch and FLAGS.switch_after_softmax:
            #TODO why does this need more gpu mem? ...  do not save logits? just calc the loss in output_fn?
            #batch size 256
            #File "/home/gezi/mine/hasky/util/melt/seq2seq/loss.py", line 154, in body
            #step_logits = logits[:, i, :]
            #ResourceExhaustedError (see above for traceback): OOM when allocating tensor with shape[256,21,33470]
            num_steps = tf.shape(targets)[1]

            loss = melt.seq2seq.exact_predict_loss(
                logits,
                targets,
                mask,
                num_steps,
                need_softmax=False,
                average_across_timesteps=not self.is_predict,
                batch_size=batch_size)
        elif self.is_predict and exact_prob:
            #generate real prob for sequence
            #for a 100k (10w) vocab textsum seq2seq, about 20 -> 4
            loss = melt.seq2seq.exact_predict_loss(
                logits,
                targets,
                mask,
                num_steps,
                batch_size=batch_size,
                average_across_timesteps=False)
        elif self.is_predict and exact_loss:
            #force non-sampled softmax loss; the difference from exact_prob is that here we just use the cross entropy error as the result, not the real probability of the sequence
            #NOTICE this uses a bit less time, 55 vs 57 (prob), and gives the same result as exact prob and exact score
            #but a 256-vocab sample will use only about 10ms
            loss = melt.seq2seq.sequence_loss_by_example(
                logits, targets, weights=mask, average_across_timesteps=False)
        else:
            #loss [batch_size,]
            loss = melt.seq2seq.sequence_loss_by_example(
                logits,
                targets,
                weights=mask,
                average_across_timesteps=not self.is_predict,  #train must average, otherwise long sentences get a big loss..
                softmax_loss_function=softmax_loss_function)

        #mainly for compat with [batch_size, num_losses]; here it may be [batch_size * num_steps,] if is_training and not exact loss/prob
        loss = tf.reshape(loss, [-1, 1])

        self.ori_loss = loss
        if self.is_predict:
            #note: use avg_loss so as not to change the loss pointer; avg_loss is the same as average_across_timesteps=True with length_normalize_factor=1.0
            avg_loss = self.normalize_length(loss, sequence_length)
            return avg_loss

        #if not is_predict, loss is averaged per time step; otherwise it is not, but avg_loss will average it
        return loss
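
The scheduled-sampling branch above goes through melt.seq2seq helpers; for reference, here is a hedged sketch of the same decode written against the stock tf.contrib.seq2seq API (the melt wrappers are assumed to follow the same calling convention, so this is an illustration, not the exact code path).

    # Scheduled-sampling decode with the stock TF 1.x contrib API (sketch).
    helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
        inputs, tf.to_int32(sequence_length), emb,
        sampling_probability=FLAGS.scheduled_sampling_probability)
    my_decoder = tf.contrib.seq2seq.BasicDecoder(cell=cell, helper=helper,
                                                 initial_state=state)
    outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(my_decoder, scope=self.scope)
    outputs = outputs.rnn_output  # [batch_size, num_steps, num_units]
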
Example #4
    def sequence_loss(self,
                      sequence,
                      initial_state=None,
                      attention_states=None,
                      input=None,
                      input_text=None,
                      exact_prob=False,
                      exact_loss=False,
                      emb=None):
        """
    For general seq2seq, input is None, sequence will be padded with <GO>, and initial_state is the last state from the encoder.
    For img2text/showandtell, input is the image_embedding and initial_state is None (zero state).
    TODO: since exact_prob and exact_loss give the same value, exact_prob may be removed.
    NOTICE! sequence is assumed to be zero-padded and at least one instance must have full length (no zeros)!
    """
        if emb is None:
            emb = self.emb

        is_training = self.is_training
        batch_size = melt.get_batch_size(sequence)

        sequence, sequence_length = melt.pad(sequence,
                                             start_id=self.get_start_id(),
                                             end_id=self.get_end_id())

        #[batch_size, num_steps - 1, emb_dim], remove last col
        inputs = tf.nn.embedding_lookup(emb, sequence[:, :-1])

        if is_training and FLAGS.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, FLAGS.keep_prob)

        #inputs [batch_size, num_steps, emb_dim]; input ([batch_size, emb_dim] -> [batch_size, 1, emb_dim]) is expanded before concat
        if input is not None:
            #used like showandtell, where image_emb is the first input in addition to the sequence
            inputs = tf.concat([tf.expand_dims(input, 1), inputs], 1)
        else:
            #common usage: input is None and the sequence itself is the input; note <GO> was already padded above by melt.pad
            sequence_length -= 1
            sequence = sequence[:, 1:]

        if self.is_predict:
            #---only needed when predicting, since train input is already dynamic length; NOTICE this improves speed a lot
            num_steps = tf.to_int32(tf.reduce_max(sequence_length))
            sequence = sequence[:, :num_steps]
            inputs = inputs[:, :num_steps, :]

        tf.add_to_collection('sequence', sequence)
        tf.add_to_collection('sequence_length', sequence_length)

        #[batch_size, num_steps]
        targets = sequence

        if attention_states is None:
            cell = self.cell
        else:
            cell = self.prepare_attention(
                attention_states,
                initial_state=initial_state,
                score_as_alignment=self.score_as_alignment)
            initial_state = None
        state = cell.zero_state(
            batch_size, tf.float32) if initial_state is None else initial_state

        if FLAGS.gen_only:
            #gen only mode
            #for the attention wrapper, dynamic_rnn cannot be used if alignments_history=True; TODO see pointer_network in application, seems ok.. why
            outputs, state = tf.nn.dynamic_rnn(cell,
                                               inputs,
                                               initial_state=state,
                                               sequence_length=sequence_length,
                                               dtype=tf.float32,
                                               scope=self.scope)

            #--------below is ok but slower than dynamic_rnn, 3.4 -> 3.1 batch/s
            #helper = melt.seq2seq.TrainingHelper(inputs, tf.to_int32(sequence_length))
            ##helper = tf.contrib.seq2seq.TrainingHelper(inputs, tf.to_int32(sequence_length))
            #my_decoder = melt.seq2seq.BasicTrainingDecoder(
            ##my_decoder = tf.contrib.seq2seq.BasicDecoder(
            ##my_decoder = melt.seq2seq.BasicDecoder(
            #      cell=cell,
            #      helper=helper,
            #      initial_state=state)
            ##outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(my_decoder, scope=self.scope)
            #outputs, state, _ = melt.seq2seq.dynamic_decode(my_decoder, scope=self.scope)
            ##outputs = outputs.rnn_output
        else:
            #---copy only or gen copy
            helper = melt.seq2seq.TrainingHelper(inputs,
                                                 tf.to_int32(sequence_length))

            indices = melt.batch_values_to_indices(tf.to_int32(input_text))
            if FLAGS.copy_only:
                output_fn = lambda cell_output, cell_state: self.copy_output_fn(
                    indices, batch_size, cell_output, cell_state)
            else:
                #gen_copy for now, not using switch
                sampled_values = None
                if self.softmax_loss_function is not None:
                    sampled_values = tf.nn.log_uniform_candidate_sampler(
                        true_classes=tf.reshape(targets, [-1, 1]),
                        num_true=1,
                        num_sampled=self.num_sampled,
                        unique=True,
                        range_max=self.vocab_size)
                    #TODO since the perf of the sampled version here is ok, not modifying now; but in addition to sampled_values,
                    #sampled_w, sampled_b could also be pre embedding lookup, may not improve much
                output_fn = lambda time, cell_output, cell_state: self.gen_copy_output_train_fn(
                    time, indices, targets, sampled_values, batch_size,
                    cell_output, cell_state)

            my_decoder = melt.seq2seq.BasicTrainingDecoder(
                cell=cell,
                helper=helper,
                initial_state=state,
                vocab_size=self.vocab_size,
                output_fn=output_fn)
            outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(
                my_decoder, scope=self.scope)
            #outputs, state, _ = melt.seq2seq.dynamic_decode(my_decoder, scope=self.scope)

        tf.add_to_collection('outputs', outputs)

        #TODO: hack here add FLAGS.predict_no_sample just for Seq2seqPredictor exact_predict
        softmax_loss_function = self.softmax_loss_function
        if self.is_predict and (exact_prob or exact_loss):
            softmax_loss_function = None

        if not FLAGS.gen_only:
            logits = outputs
            softmax_loss_function = None
        elif softmax_loss_function is not None:
            logits = outputs
        else:
            #[batch_size, num_steps, num_units] * [num_units, vocab_size]
            # -> logits [batch_size, num_steps, vocab_size] (if use exact_predict_loss)
            #or [batch_size * num_steps, vocab_size] by default flatten=True
            keep_dims = exact_prob or exact_loss
            logits = melt.batch_matmul_embedding(
                outputs, self.w, keep_dims=keep_dims) + self.v
            if not keep_dims:
                targets = tf.reshape(targets, [-1])

        tf.add_to_collection('logits', logits)

        #if input_text is not None:
        #  logits = outputs

        mask = tf.cast(tf.sign(targets), dtype=tf.float32)

        if FLAGS.gen_copy_switch:
            #TODO why does this need more gpu mem? ...  do not save logits? just calc the loss in output_fn?
            #batch size 256
            #File "/home/gezi/mine/hasky/util/melt/seq2seq/loss.py", line 154, in body
            #step_logits = logits[:, i, :]
            #ResourceExhaustedError (see above for traceback): OOM when allocating tensor with shape[256,21,33470]
            num_steps = tf.shape(targets)[1]

            loss = melt.seq2seq.exact_predict_loss(logits,
                                                   targets,
                                                   mask,
                                                   num_steps,
                                                   need_softmax=False,
                                                   need_average=True,
                                                   batch_size=batch_size)

            # loss = melt.seq2seq.sequence_loss_by_example(
            #     logits,
            #     targets,
            #     weights=mask)
        elif self.is_predict and exact_prob:
            #generate real prob for sequence
            #for a 100k (10w) vocab textsum seq2seq, about 20 -> 4
            loss = melt.seq2seq.exact_predict_loss(logits,
                                                   targets,
                                                   mask,
                                                   num_steps,
                                                   batch_size=batch_size)
        elif self.is_predict and exact_loss:
            #force non-sampled softmax loss; the difference from exact_prob is that here we just use the cross entropy error as the result, not the real probability of the sequence
            #NOTICE this uses a bit less time, 55 vs 57 (prob), and gives the same result as exact prob and exact score
            #but a 256-vocab sample will use only about 10ms
            loss = melt.seq2seq.sequence_loss_by_example(logits,
                                                         targets,
                                                         weights=mask)
        else:
            #loss [batch_size,]
            loss = melt.seq2seq.sequence_loss_by_example(
                logits,
                targets,
                weights=mask,
                softmax_loss_function=softmax_loss_function)

        #mainly for compat with [batch_size, num_losses]
        loss = tf.reshape(loss, [-1, 1])

        if self.is_predict:
            loss = self.normalize_length(loss, sequence_length, exact_prob)
            #loss = tf.squeeze(loss)  TODO: uncomment later once all models are rerun
        return loss
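
For reference, the mask built from tf.sign(targets) is what keeps padded positions (id 0) out of the loss. Below is a minimal sketch of the per-example masked cross entropy that sequence_loss_by_example is assumed to compute when logits keep their [batch_size, num_steps, vocab_size] shape; it is an illustration, not the melt implementation.

    # Per-example masked cross entropy (sketch); padded targets (id 0) get weight 0.
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=logits)
    mask = tf.cast(tf.sign(targets), tf.float32)
    loss_per_example = tf.reduce_sum(xent * mask, 1) / tf.reduce_sum(mask, 1)
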
Example #5
  def sequence_loss(self, input, sequence, initial_state=None, emb=None):
    if emb is None:
      emb = self.emb
    
    is_training = self.is_training
    batch_size = tf.shape(sequence)[0]

    sequence, sequence_length = melt.pad(sequence,
                                     start_id=self.get_start_id(),
                                     end_id=self.get_end_id())

    #TODO different init state as shown in ptb_word_lm
    state = self.cell.zero_state(batch_size, tf.float32) if initial_state is None else initial_state

    #[batch_size, num_steps - 1, emb_dim], remove last col
    inputs = tf.nn.embedding_lookup(emb, melt.dynamic_exclude_last_col(sequence))
    
    if is_training and FLAGS.keep_prob < 1:
      inputs = tf.nn.dropout(inputs, FLAGS.keep_prob)
    
    #inputs [batch_size, num_steps, emb_dim]; input ([batch_size, emb_dim] -> [batch_size, 1, emb_dim]) is expanded before concat
    if input is not None:
      #used like showandtell, where image_emb is the first input in addition to the sequence
      inputs = tf.concat(1, [tf.expand_dims(input, 1), inputs])
    else:
      #common usage: input is None and the sequence itself is the input; note <GO> was already padded above by melt.pad
      sequence_length -= 1
      sequence = sequence[:, 1:]

    if self.is_predict:
      #---only needed when predicting, since train input is already dynamic length; NOTICE this improves speed a lot
      num_steps = tf.cast(tf.reduce_max(sequence_length), dtype=tf.int32)
      sequence = tf.slice(sequence, [0,0], [-1, num_steps])

    outputs, state = tf.nn.dynamic_rnn(self.cell, inputs, 
                                       initial_state=state, 
                                       sequence_length=sequence_length)
    self.final_state = state
    
    if self.softmax_loss_function is None:
      #[batch_size, num_steps, num_units] * [num_units, vocab_size] -> logits [batch_size, num_steps, vocab_size]
      logits = melt.batch_matmul_embedding(outputs, self.w) + self.v
    else:
      logits = outputs

    #[batch_size, num_steps]
    targets = sequence
    mask = tf.cast(tf.sign(sequence), dtype=tf.float32)
    
    if self.is_predict and FLAGS.predict_no_sample:
      loss = melt.seq2seq.exact_predict_loss(logits, batch_size, num_steps)
    else:
      #loss [batch_size,] 
      loss = melt.seq2seq.sequence_loss_by_example(
          logits,
          targets,
          weights=mask,
          softmax_loss_function=self.softmax_loss_function)
    
    #mainly for compat with [batch_size, num_losses]
    loss = tf.reshape(loss, [-1, 1])
    if self.is_predict:
      loss = self.normalize_length(loss, sequence_length)
 
    return loss
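
In the predict path, normalize_length is assumed to rescale the per-example loss by the sequence length so that scores of sequences of different lengths are comparable. A hedged guess at its simplest form (length normalization factor 1.0; the real melt helper may use a configurable exponent):

    # Simplest length normalization (guess, not the actual melt code): divide by length.
    lengths = tf.to_float(tf.reshape(sequence_length, [-1, 1]))
    normalized_loss = loss / lengths
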