Example #1
0
    def update_state(self, y_true, y_pred):
        """Accumulate masked corrected predictions for one batch.

        Args:
            y_true: targets, shape [batch_size, sequence_size].
            y_pred: predictions, shape [batch_size, sequence_size].
        """
        # tf.cast (not tf.convert_to_tensor with a dtype) so that a y_true
        # tensor of a different dtype is converted instead of raising.
        y_true_tensor = tf.cast(y_true, dtype=y_pred.dtype)
        target_sequence_length = sequence_length_2D(y_true_tensor)
        # masked_accuracy returns a 4-tuple; only the second element is
        # accumulated by this metric. Pass the dtype-aligned targets so
        # they match y_pred.
        _, masked_corrected_predictions, _, _ = masked_accuracy(
            y_true_tensor, y_pred, target_sequence_length)

        super().update_state(masked_corrected_predictions)
Example #2
0
    def update_state(self, y_true, y_pred):
        """Accumulate the edit distance between targets and predictions.

        Both y_true and y_pred are shaped [batch_size, sequence_size].
        y_true is cast to y_pred's dtype before the lengths and distance
        are computed.
        """
        pred_lengths = sequence_length_2D(y_pred)
        targets = tf.cast(y_true, dtype=y_pred.dtype)
        target_lengths = sequence_length_2D(targets)
        distance, _ = edit_distance(targets, target_lengths,
                                    y_pred, pred_lengths)
        super().update_state(distance)
Example #3
0
    def decoder_teacher_forcing(
            self,
            encoder_output,
            target=None,
            encoder_end_state=None
    ):
        """Decode with teacher forcing and return logits masked past each length.

        The target is wrapped with GO / END columns, embedded, and fed to a
        tfa.seq2seq.BasicDecoder. The resulting logits are multiplied by a
        sequence mask built from the generated sequence lengths.

        Args:
            encoder_output: encoder output; its first dimension is used as
                the batch size.
            target: target tensor. It is concatenated on axis 1 with
                [batch, 1] GO/END columns, so it is presumably rank 2.
            encoder_end_state: optional encoder final state, passed to
                build_decoder_initial_state.

        Returns:
            The decoder rnn_output logits, zeroed beyond each generated
            sequence length.
        """
        # ================ Setup ================
        # Dynamic shape: the static encoder_output.shape[0] is None when the
        # batch dimension is unknown at trace time, which breaks tf.tile.
        batch_size = tf.shape(encoder_output)[0]

        # Prepare target for decoding
        target_sequence_length = sequence_length_2D(target)
        start_tokens = tf.tile([self.GO_SYMBOL], [batch_size])
        end_tokens = tf.tile([self.END_SYMBOL], [batch_size])
        if self.is_timeseries:
            start_tokens = tf.cast(start_tokens, tf.float32)
            end_tokens = tf.cast(end_tokens, tf.float32)
        targets_with_go_and_eos = tf.concat([
            tf.expand_dims(start_tokens, 1),
            target,  # todo tf2: right now cast to tf.int32, fails if tf.int64
            tf.expand_dims(end_tokens, 1)], 1)
        # One extra decoding step for the appended END symbol.
        target_sequence_length_with_eos = target_sequence_length + 1

        # Decoder Embeddings
        decoder_emb_inp = self.decoder_embedding(targets_with_go_and_eos)

        # Setting up decoder memory from encoder output
        if self.attention_mechanism is not None:
            encoder_sequence_length = sequence_length_3D(encoder_output)
            self.attention_mechanism.setup_memory(
                encoder_output,
                memory_sequence_length=encoder_sequence_length
            )

        decoder_initial_state = self.build_decoder_initial_state(
            batch_size,
            encoder_state=encoder_end_state,
            dtype=tf.float32
        )

        decoder = tfa.seq2seq.BasicDecoder(
            self.decoder_rnncell,
            sampler=self.sampler,
            output_layer=self.dense_layer
        )

        # BasicDecoderOutput
        outputs, _, generated_sequence_lengths = decoder(
            decoder_emb_inp,
            initial_state=decoder_initial_state,
            sequence_length=target_sequence_length_with_eos
        )

        logits = outputs.rnn_output
        # Dynamic time dimension so the mask always matches logits, even when
        # logits.shape[1] is not known statically.
        mask = tf.sequence_mask(
            generated_sequence_lengths,
            maxlen=tf.shape(logits)[1],
            dtype=tf.float32
        )
        logits = logits * mask[:, :, tf.newaxis]
        return logits
Example #4
0
    def call(self, y_true, y_pred):
        """Compute the sequence loss averaged over the valid time steps.

        Args:
            y_true: targets, shape [batch_size, sequence_size]; cast to int64.
            y_pred: dict whose LOGITS entry has shape
                [batch_size, sequence_size, num_classes].

        Returns:
            Scalar loss: the masked per-step loss summed and divided by the
            number of unmasked steps.
        """
        logits = y_pred[LOGITS]
        targets = tf.cast(y_true, dtype=tf.int64)

        # Right-pad whichever tensor has the shorter time axis so both
        # share the same sequence dimension.
        logits_steps = tf.shape(logits)[1]
        target_steps = tf.shape(targets)[1]
        logits = tf.pad(
            logits,
            [[0, 0], [0, tf.maximum(0, target_steps - logits_steps)], [0, 0]])
        targets = tf.pad(
            targets,
            [[0, 0], [0, tf.maximum(0, logits_steps - target_steps)]])

        padded_steps = tf.shape(targets)[1]
        # Valid steps are the longer of the target / prediction lengths,
        # plus one for EOS, capped at the padded sequence size.
        valid_lengths = tf.minimum(
            tf.maximum(sequence_length_2D(targets),
                       sequence_length_3D(logits)) + 1,
            padded_steps)
        step_mask = tf.sequence_mask(valid_lengths,
                                     maxlen=padded_steps,
                                     dtype=tf.float32)

        masked_loss = self.loss_function(targets, logits) * step_mask
        return tf.reduce_sum(masked_loss) / tf.reduce_sum(step_mask)
Example #5
0
def sequence_sampled_softmax_cross_entropy(targets, train_logits,
                                           decoder_weights, decoder_biases,
                                           num_classes, **loss):
    """Compute a per-example sampled-softmax loss over a target sequence.

    targets and train_logits are right-padded along axis 1 to a common
    length. Candidate classes are drawn with sample_values_from_classes, and
    tf.nn.sampled_softmax_loss is applied per step through
    tfa.seq2seq.sequence_loss. Steps beyond each target length are masked.

    Args:
        targets: target class ids, rank 2 (padded along axis 1).
        train_logits: rank-3 tensor fed as `inputs` to the sampled softmax
            (padded along axis 1).
        decoder_weights: output projection weights; transposed before use.
        decoder_biases: output projection biases.
        num_classes: total number of classes.
        **loss: loss configuration. The keys read are 'sampler',
            'negative_samples', 'unique', 'class_counts' and 'distortion'.

    Returns:
        The loss averaged across time steps but not across the batch.
    """
    target_steps = tf.shape(targets)[1]
    target_lengths = sequence_length_2D(tf.cast(targets, tf.int64))
    logit_steps = tf.shape(train_logits)[1]

    # Right-pad whichever input is shorter along the time axis.
    logit_padding = tf.maximum(0, target_steps - logit_steps)
    target_padding = tf.maximum(0, logit_steps - target_steps)
    padded_logits = tf.pad(train_logits, [[0, 0], [0, logit_padding], [0, 0]])
    padded_targets = tf.pad(targets, [[0, 0], [0, target_padding]])

    flat_labels = tf.cast(tf.reshape(padded_targets, [-1, 1]), tf.int64)
    sampled_values = sample_values_from_classes(
        flat_labels, loss['sampler'], num_classes, loss['negative_samples'],
        loss['unique'], loss['class_counts'], loss['distortion'])

    if loss['sampler'] == 'fixed_unigram':
        # Rebuild the sampler tuple with EPSILON added to
        # true_expected_count so it contains no zero entries.
        sampled_values = FixedUnigramCandidateSampler(
            sampled_values.sampled_candidates,
            tf.add(sampled_values.true_expected_count, EPSILON),
            sampled_values.sampled_expected_count)

    def _per_step_sampled_loss(labels, logits):
        # Used as sequence_loss's softmax_loss_function.
        column_labels = tf.reshape(tf.cast(labels, tf.int64), [-1, 1])
        step_loss = tf.nn.sampled_softmax_loss(
            weights=tf.transpose(decoder_weights),
            biases=decoder_biases,
            labels=column_labels,
            inputs=tf.cast(logits, tf.float32),
            num_sampled=loss['negative_samples'],
            num_classes=num_classes,
            sampled_values=sampled_values)
        return tf.cast(step_loss, tf.float32)

    step_mask = tf.sequence_mask(target_lengths,
                                 tf.shape(padded_targets)[1],
                                 dtype=tf.float32)
    return tfa.seq2seq.sequence_loss(
        padded_logits,
        padded_targets,
        step_mask,
        average_across_timesteps=True,
        average_across_batch=False,
        softmax_loss_function=_per_step_sampled_loss)
Example #6
0
    def update_state(self, y_true, y_pred):
        """Accumulate masked corrected predictions for one batch.

        y_true and y_pred are both shaped [batch_size, sequence_size].
        y_true is cast to y_pred's dtype, and its lengths define the mask.
        """
        targets = tf.cast(y_true, dtype=y_pred.dtype)
        super().update_state(
            masked_corrected_predictions(
                targets, y_pred, sequence_length_2D(targets)))
Example #7
0
 def update_state(self, y_true, y_pred, sample_weight=None):
     """Feed the parent metric the last target of each sequence.

     y_true is cast to int64. For each row, the element at index
     max(length - 1, 0) is gathered and compared against y_pred by the
     parent's update_state.
     """
     targets = tf.cast(y_true, dtype=tf.int64)
     last_index = tf.maximum(sequence_length_2D(targets) - 1, 0)
     row_index = tf.range(tf.shape(targets)[0])
     last_targets = tf.gather_nd(
         targets, tf.stack([row_index, last_index], axis=1))
     super().update_state(last_targets, y_pred, sample_weight=sample_weight)
Example #8
0
    def _predictions_eval(
            self,
            inputs,  # encoder_output, encoder_output_state
            training=None
    ):
        """Run the decoder and derive predictions, probabilities and logits.

        Args:
            inputs: forwarded to self.call (encoder_output,
                encoder_output_state per the inline comment).
            training: forwarded to self.call.

        Returns:
            Dict with:
            - PREDICTIONS: int64 argmax over the last axis.
            - LAST_PREDICTIONS: each row's prediction at index
              max(length - 1, 0).
            - PROBABILITIES: softmax of the unmasked logits.
            - LOGITS: logits zeroed beyond each generated length.
        """
        logits = self.call(inputs, training=training)

        probabilities = tf.nn.softmax(
            logits,
            name='probabilities_{}'.format(self.name)
        )

        predictions = tf.argmax(
            logits,
            -1,
            name='predictions_{}'.format(self.name),
            output_type=tf.int64
        )

        # todo tf2: deal with spurious 0s in predictions
        generated_sequence_lengths = sequence_length_2D(predictions)
        last_predictions = tf.gather_nd(
            predictions,
            tf.stack(
                [tf.range(tf.shape(predictions)[0]),
                 tf.maximum(
                     generated_sequence_lengths - 1,
                     0
                 )],
                axis=1
            ),
            name='last_predictions_{}'.format(self.name)
        )

        # Mask logits. Use the dynamic time dimension: logits.shape[1] is
        # None when the length is not static. Then tf.sequence_mask would
        # size the mask by max(lengths), which may not broadcast against
        # logits.
        mask = tf.sequence_mask(
            generated_sequence_lengths,
            maxlen=tf.shape(logits)[1],
            dtype=tf.float32
        )

        logits = logits * mask[:, :, tf.newaxis]

        return {
            PREDICTIONS: predictions,
            LAST_PREDICTIONS: last_predictions,
            PROBABILITIES: probabilities,
            LOGITS: logits
        }
Example #9
0
    def update_state(self, y_true, y_pred, sample_weight=None):
        """Feed the parent metric the last target of each sequence.

        For each row of y_true, the element at index max(length - 1, 0) is
        gathered as int64 and passed with y_pred to the parent update_state.

        Args:
            y_true: targets, rank 2 (rows are indexed by batch position).
            y_pred: predictions forwarded unchanged to the parent metric.
            sample_weight: currently ignored (see TODO).
        """
        # TODO TF2 account for weights
        # Cast once up front. tf.convert_to_tensor(..., dtype=tf.int64)
        # raises for a tensor of another dtype, whereas tf.cast converts it.
        # This also makes the length computation and the gather run on
        # the same tensor.
        y_true = tf.cast(y_true, dtype=tf.int64)
        targets_sequence_length = sequence_length_2D(y_true)
        last_targets = tf.gather_nd(
            y_true,
            tf.stack([
                tf.range(tf.shape(y_true)[0]),
                tf.maximum(targets_sequence_length - 1, 0)
            ], axis=1))

        super().update_state(last_targets, y_pred)
Example #10
0
    def call(self, y_true, y_pred):
        """Compute the sequence loss averaged over the valid time steps.

        Args:
            y_true: targets, shape [batch_size, sequence_size]; cast to int64.
            y_pred: dict holding LOGITS (when self.from_logits) or
                PROBABILITIES, shape [batch_size, sequence_size, num_classes].

        Returns:
            Scalar loss: the masked per-step loss summed and divided by the
            number of unmasked steps.
        """
        scores = y_pred[LOGITS] if self.from_logits else y_pred[PROBABILITIES]
        labels = tf.cast(y_true, dtype=tf.int64)

        # Right-pad the shorter of the two along the time axis (axis 1).
        score_steps = tf.shape(scores)[1]
        label_steps = tf.shape(labels)[1]
        scores = tf.pad(
            scores,
            [[0, 0], [0, tf.maximum(0, label_steps - score_steps)], [0, 0]])
        labels = tf.pad(
            labels,
            [[0, 0], [0, tf.maximum(0, score_steps - label_steps)]])

        # The +1 covers the EOS step for generators. Per the original
        # author, it should not hurt taggers.
        step_mask = tf.sequence_mask(
            sequence_length_2D(labels) + 1,
            maxlen=tf.shape(labels)[1],
            dtype=tf.float32
        )
        masked_loss = self.loss_function(labels, scores) * step_mask
        return tf.reduce_sum(masked_loss) / tf.reduce_sum(step_mask)
Example #11
0
    def call(self, y_true, y_pred):
        """Compute the sequence loss averaged over the valid time steps.

        Args:
            y_true: targets, shape [batch_size, sequence_size]; cast to int64.
            y_pred: dict whose LOGITS entry has shape
                [batch_size, sequence_size, num_classes].

        Returns:
            Scalar loss: the masked per-step loss summed and divided by the
            number of unmasked steps.
        """
        y_pred = y_pred[LOGITS]
        # tf.cast also accepts tensors of another integer dtype, which
        # tf.convert_to_tensor(..., dtype=tf.int64) rejects.
        y_true = tf.cast(y_true, dtype=tf.int64)

        # Right-pad the shorter sequence with zeros. Dynamic shapes are used
        # so this works when batch or time dims are unknown statically;
        # comparing static `.shape` entries raises TypeError when one is None.
        y_pred_seq_len = tf.shape(y_pred)[1]
        y_true_seq_len = tf.shape(y_true)[1]
        y_pred = tf.pad(
            y_pred,
            [[0, 0], [0, tf.maximum(0, y_true_seq_len - y_pred_seq_len)], [0, 0]])
        y_true = tf.pad(
            y_true,
            [[0, 0], [0, tf.maximum(0, y_pred_seq_len - y_true_seq_len)]])

        max_steps = tf.shape(y_true)[1]
        longest_sequence_length = tf.maximum(sequence_length_2D(y_true),
                                             sequence_length_3D(y_pred))
        longest_sequence_length += 1  # for EOS
        longest_sequence_length = tf.minimum(longest_sequence_length,
                                             max_steps)
        mask = tf.sequence_mask(longest_sequence_length,
                                maxlen=max_steps,
                                dtype=tf.float32)
        # compute loss based on valid time steps
        loss = self.loss_function(y_true, y_pred)
        loss = loss * mask
        loss = tf.reduce_sum(loss) / tf.reduce_sum(mask)
        return loss
Example #12
0
    def decoder_teacher_forcing(
            self,
            encoder_output,
            target=None,
            encoder_end_state=None
    ):
        """Decode with teacher forcing using Ludwig's custom BasicDecoder.

        The target is wrapped with GO / END columns, embedded and decoded.
        One all-zero time step is appended to the logits so that sequences
        reaching the maximum length, and so lacking an EOS, still have a
        trailing step.

        Args:
            encoder_output: encoder output; its first dynamic dimension is
                the batch size.
            target: target tensor, concatenated on axis 1 with [batch, 1]
                GO/END columns.
            encoder_end_state: optional encoder final state, passed to
                build_decoder_initial_state.

        Returns:
            Dict with:
            - LOGITS: shape [batch_size, seq_size, num_classes], used for
              evaluation.
            - PROJECTION_INPUT: shape [batch_size, seq_size, state_size],
              used for sampled softmax.
        """
        batch_size = tf.shape(encoder_output)[0]

        # Wrap the targets with GO / END columns.
        target_lengths = sequence_length_2D(target)
        go_column = tf.tile([self.GO_SYMBOL], [batch_size])
        end_column = tf.tile([self.END_SYMBOL], [batch_size])
        if self.is_timeseries:
            go_column = tf.cast(go_column, tf.float32)
            end_column = tf.cast(end_column, tf.float32)
        # Note: target is currently expected to match the tiled symbols'
        # dtype (int32); an int64 target fails here.
        decoder_inputs = tf.concat(
            [tf.expand_dims(go_column, 1),
             target,
             tf.expand_dims(end_column, 1)],
            1)
        decoder_lengths = target_lengths + 1  # one extra step for END

        embedded_inputs = self.decoder_embedding(decoder_inputs)

        if self.attention_mechanism is not None:
            self.attention_mechanism.setup_memory(
                encoder_output,
                memory_sequence_length=sequence_length_3D(encoder_output)
            )

        initial_state = self.build_decoder_initial_state(
            batch_size,
            encoder_state=encoder_end_state,
            dtype=tf.float32
        )

        # Ludwig's own BasicDecoder, whose outputs also carry
        # projection_input.
        decoder = BasicDecoder(
            self.decoder_rnncell,
            sampler=self.sampler,
            output_layer=self.dense_layer
        )
        decoder_outputs, _, _ = decoder(
            embedded_inputs,
            initial_state=initial_state,
            sequence_length=decoder_lengths
        )

        # Trailing zero step for rows that hit max length without an EOS.
        padded_logits = tf.pad(
            decoder_outputs.rnn_output,
            [[0, 0], [0, 1], [0, 0]]
        )

        return {
            LOGITS: padded_logits,
            PROJECTION_INPUT: decoder_outputs.projection_input
        }
    def __call__(self,
                 output_feature,
                 targets,
                 hidden,
                 hidden_size,
                 regularizer,
                 is_timeseries=False):
        """Build a per-timestep (tagger) classification head over `hidden`.

        Args:
            output_feature: feature dict. 'num_classes' is read (and set to
                1 in place when is_timeseries); 'name' is read for op names
                in the non-timeseries case.
            targets: target tensor whose lengths are computed with
                sequence_length_2D.
            hidden: rank-3 input [batch x sequence x hidden].
            hidden_size: size of hidden's last dimension, used for reshaping
                and the weight shape.
            regularizer: applied to the weights unless self.regularize is
                false.
            is_timeseries: when true, logits are returned as predictions and
                probabilities are zeros.

        Returns:
            9-tuple: (predictions_sequence, probabilities_sequence,
            predictions_sequence_length, probabilities_sequence,
            targets_sequence_length, logits, hidden, class_weights,
            class_biases). probabilities_sequence appears twice.

        Raises:
            ValueError: if hidden is not rank 3.
        """
        logging.info('  hidden shape: {0}'.format(hidden.shape))
        hidden_rank = len(hidden.shape)
        if hidden_rank != 3:
            raise ValueError(
                'Decoder inputs rank is {}, but should be 3 [batch x sequence x hidden] '
                'when using a tagger sequential decoder. '
                'Consider setting reduce_output to null / None if a sequential encoder / combiner is used.'
                .format(hidden_rank))

        if is_timeseries:
            # Mutates the caller's dict.
            output_feature['num_classes'] = 1

        if not self.regularize:
            regularizer = None

        # Time dimension taken before the optional attention step.
        sequence_length = tf.shape(hidden)[1]

        if self.attention:
            hidden, hidden_size = feed_forward_memory_attention(
                hidden, hidden, hidden_size)
        targets_sequence_length = sequence_length_2D(targets)

        initializer_obj = get_initializer(self.initializer)
        num_classes = output_feature['num_classes']
        class_weights = tf.get_variable(
            'weights',
            initializer=initializer_obj([hidden_size, num_classes]),
            regularizer=regularizer)
        logging.debug('  weights: {0}'.format(class_weights))

        class_biases = tf.get_variable('biases', [num_classes])
        logging.debug('  biases: {0}'.format(class_biases))

        # Project every time step at once by flattening batch and time.
        flat_hidden = tf.reshape(hidden, [-1, hidden_size])
        flat_logits = tf.matmul(flat_hidden, class_weights) + class_biases
        logits = tf.reshape(flat_logits, [-1, sequence_length, num_classes])
        logging.debug('  logits: {0}'.format(logits))

        if is_timeseries:
            probabilities_sequence = tf.zeros_like(logits)
            predictions_sequence = tf.reshape(logits, [-1, sequence_length])
        else:
            feature_name = output_feature['name']
            probabilities_sequence = tf.nn.softmax(
                logits, name='probabilities_{}'.format(feature_name))
            predictions_sequence = tf.argmax(
                logits,
                -1,
                name='predictions_{}'.format(feature_name),
                output_type=tf.int32)

        predictions_sequence_length = sequence_length_3D(hidden)

        return (predictions_sequence, probabilities_sequence,
                predictions_sequence_length,
                probabilities_sequence, targets_sequence_length,
                logits, hidden, class_weights, class_biases)
Example #14
0
def recurrent_decoder(encoder_outputs,
                      targets,
                      max_sequence_length,
                      vocab_size,
                      cell_type='rnn',
                      state_size=256,
                      embedding_size=50,
                      num_layers=1,
                      attention_mechanism=None,
                      beam_width=1,
                      projection=True,
                      tied_target_embeddings=True,
                      embeddings=None,
                      initializer=None,
                      regularizer=None,
                      is_timeseries=False):
    """Build a TF1 seq2seq RNN decoder graph (teacher-forced and inference).

    Targets are wrapped with a GO symbol (vocab_size) and an END symbol (0).
    They are then embedded, or for timeseries given a trailing axis, and
    decoded with teacher forcing. For non-timeseries outputs, a second
    greedy or beam-search pass produces the predictions.

    Args:
        encoder_outputs: encoder output. It is projected to state_size when
            `projection` is set and its last dimension differs, and the
            decoder initial state is derived from it.
        targets: target sequences, concatenated on axis 1 with GO/END
            columns.
        max_sequence_length: maximum_iterations for dynamic_decode.
        vocab_size: number of output symbols. Embeddings and class weights
            have vocab_size + 1 entries to include GO.
        cell_type: passed to get_cell_fun. Names starting with 'lstm' get an
            LSTMStateTuple initial state.
        state_size: RNN state size.
        embedding_size: target embedding size. Overridden by `embeddings`,
            or by state_size when embeddings are tied.
        num_layers: number of stacked cells; must be >= 1.
        attention_mechanism: None, 'bahdanau' or 'luong'.
        beam_width: > 1 uses BeamSearchDecoder at inference. Not allowed
            with is_timeseries.
        projection: whether to project encoder outputs to state_size.
        tied_target_embeddings: reuse the transposed target embeddings as
            class weights.
        embeddings: optional embedding matrix; a trainable GO row is
            appended to it.
        initializer: initializer name for get_initializer / fc_layer.
        regularizer: regularizer for the variable scope and variables.
        is_timeseries: decode real values instead of symbols.

    Returns:
        Tuple (predictions_sequence, predictions_sequence_scores,
        predictions_sequence_length_with_eos,
        targets_sequence_length_with_eos, eval_logits, train_logits,
        class_weights, class_biases).

    Raises:
        ValueError: for beam_width > 1 with is_timeseries, num_layers <= 0,
            or an unsupported attention mechanism.
    """
    with tf.variable_scope('rnn_decoder',
                           reuse=tf.AUTO_REUSE,
                           regularizer=regularizer):

        # ================ Setup ================
        if beam_width > 1 and is_timeseries:
            raise ValueError('Invalid beam_width: {}'.format(beam_width))

        GO_SYMBOL = vocab_size
        END_SYMBOL = 0
        batch_size = tf.shape(encoder_outputs)[0]

        # ================ Projection ================
        # Project the encoder outputs to the size of the decoder state
        encoder_outputs_size = encoder_outputs.shape[-1]
        if projection and encoder_outputs_size != state_size:
            with tf.variable_scope('projection'):
                encoder_output_rank = len(encoder_outputs.shape)
                if encoder_output_rank > 2:
                    # Flatten all leading dims, project, then restore the
                    # [batch, time, state_size] layout.
                    sequence_length = tf.shape(encoder_outputs)[1]
                    encoder_outputs = tf.reshape(encoder_outputs,
                                                 [-1, encoder_outputs_size])
                    encoder_outputs = fc_layer(encoder_outputs,
                                               encoder_outputs.shape[-1],
                                               state_size,
                                               activation=None,
                                               initializer=initializer)
                    encoder_outputs = tf.reshape(
                        encoder_outputs, [-1, sequence_length, state_size])
                else:
                    encoder_outputs = fc_layer(encoder_outputs,
                                               encoder_outputs.shape[-1],
                                               state_size,
                                               activation=None,
                                               initializer=initializer)

        # ================ Targets sequence ================
        # Calculate the length of inputs and the batch size
        with tf.variable_scope('sequence'):
            targets_sequence_length = sequence_length_2D(targets)
            start_tokens = tf.tile([GO_SYMBOL], [batch_size])
            end_tokens = tf.tile([END_SYMBOL], [batch_size])
            if is_timeseries:
                start_tokens = tf.cast(start_tokens, tf.float32)
                end_tokens = tf.cast(end_tokens, tf.float32)
            targets_with_go_and_eos = tf.concat([
                tf.expand_dims(start_tokens, 1), targets,
                tf.expand_dims(end_tokens, 1)
            ], 1)
            logging.debug(
                '  targets_with_go: {0}'.format(targets_with_go_and_eos))
            # +1 for the appended EOS step. Original note: the EOS symbol is
            # 0, so it does not increase the real length of the sequence.
            targets_sequence_length_with_eos = targets_sequence_length + 1

        # ================ Embeddings ================
        if is_timeseries:
            targets_embedded = tf.expand_dims(targets_with_go_and_eos, -1)
            targets_embeddings = None
        else:
            with tf.variable_scope('embedding'):
                if embeddings is not None:
                    embedding_size = embeddings.shape.as_list()[-1]
                    if tied_target_embeddings:
                        state_size = embedding_size
                elif tied_target_embeddings:
                    embedding_size = state_size

                if embeddings is not None:
                    # Extra trainable row for the GO symbol (index
                    # vocab_size).
                    embedding_go = tf.get_variable(
                        'embedding_GO',
                        initializer=tf.random_uniform([1, embedding_size],
                                                      -1.0, 1.0))
                    targets_embeddings = tf.concat([embeddings, embedding_go],
                                                   axis=0)
                else:
                    initializer_obj = get_initializer(initializer)
                    targets_embeddings = tf.get_variable(
                        'embeddings',
                        initializer=initializer_obj(
                            [vocab_size + 1, embedding_size]),
                        regularizer=regularizer)
                logging.debug(
                    '  targets_embeddings: {0}'.format(targets_embeddings))

                targets_embedded = tf.nn.embedding_lookup(
                    targets_embeddings,
                    targets_with_go_and_eos,
                    name='decoder_input_embeddings')
        logging.debug('  targets_embedded: {0}'.format(targets_embedded))

        # ================ Class prediction ================
        # NOTE(review): with is_timeseries, targets_embeddings is None, so
        # tied_target_embeddings=True would fail here. Callers presumably
        # disable tying for timeseries; confirm.
        if tied_target_embeddings:
            class_weights = tf.transpose(targets_embeddings)
        else:
            initializer_obj = get_initializer(initializer)
            class_weights = tf.get_variable('class_weights',
                                            initializer=initializer_obj(
                                                [state_size, vocab_size + 1]),
                                            regularizer=regularizer)
        logging.debug('  class_weights: {0}'.format(class_weights))
        class_biases = tf.get_variable('class_biases', [vocab_size + 1])
        logging.debug('  class_biases: {0}'.format(class_biases))
        projection_layer = Projection(class_weights, class_biases)

        # ================ RNN ================
        initial_state = encoder_outputs
        with tf.variable_scope('rnn_cells') as vs:
            # Cell
            cell_fun = get_cell_fun(cell_type)

            if num_layers == 1:
                cell = cell_fun(state_size)
                if cell_type.startswith('lstm'):
                    initial_state = LSTMStateTuple(c=initial_state,
                                                   h=initial_state)
            elif num_layers > 1:
                cell = MultiRNNCell(
                    [cell_fun(state_size) for _ in range(num_layers)],
                    state_is_tuple=True)
                if cell_type.startswith('lstm'):
                    initial_state = LSTMStateTuple(c=initial_state,
                                                   h=initial_state)
                initial_state = tuple([initial_state] * num_layers)
            else:
                raise ValueError(
                    'num_layers in recurrent decoser: {}. '
                    'Number of layers in a recurrenct decoder cannot be <= 0'.
                    format(num_layers))

            # Attention
            if attention_mechanism is not None:
                if attention_mechanism == 'bahdanau':
                    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                        num_units=state_size,
                        memory=encoder_outputs,
                        memory_sequence_length=sequence_length_3D(
                            encoder_outputs))
                elif attention_mechanism == 'luong':
                    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
                        num_units=state_size,
                        memory=encoder_outputs,
                        memory_sequence_length=sequence_length_3D(
                            encoder_outputs))
                else:
                    raise ValueError(
                        'Attention mechanism {} not supported'.format(
                            attention_mechanism))
                cell = tf.contrib.seq2seq.AttentionWrapper(
                    cell, attention_mechanism, attention_layer_size=state_size)
                # With attention, the initial cell state is the last step of
                # the encoder outputs instead of the raw encoder outputs.
                initial_state = cell.zero_state(dtype=tf.float32,
                                                batch_size=batch_size)
                initial_state = initial_state.clone(
                    cell_state=reduce_sequence(encoder_outputs, 'last'))

            for v in tf.global_variables():
                if v.name.startswith(vs.name):
                    logging.debug('  {}: {}'.format(v.name, v))

        # ================ Decoding ================
        def decode(initial_state,
                   cell,
                   helper,
                   beam_width=1,
                   projection_layer=None):
            # Build a beam-search or basic decoder and run dynamic_decode.
            if beam_width > 1:
                # Tile inputs for beam search decoder
                beam_initial_state = tf.contrib.seq2seq.tile_batch(
                    initial_state, beam_width)
                decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                    cell=cell,
                    embedding=targets_embeddings,
                    start_tokens=start_tokens,
                    end_token=END_SYMBOL,
                    initial_state=beam_initial_state,
                    beam_width=beam_width,
                    output_layer=projection_layer)
            else:
                decoder = BasicDecoder(cell=cell,
                                       helper=helper,
                                       initial_state=initial_state,
                                       output_layer=projection_layer)

            # The decoding operation
            outputs = tf.contrib.seq2seq.dynamic_decode(
                decoder=decoder,
                output_time_major=False,
                impute_finished=False if beam_width > 1 else True,
                maximum_iterations=max_sequence_length)

            return outputs

        # ================ Decoding helpers ================
        if is_timeseries:
            train_helper = TimeseriesTrainingHelper(
                inputs=targets_embedded,
                sequence_length=targets_sequence_length_with_eos)
            final_outputs_pred, final_state_pred, final_sequence_lengths_pred = decode(
                initial_state,
                cell,
                train_helper,
                projection_layer=projection_layer)
            eval_logits = final_outputs_pred.rnn_output
            train_logits = final_outputs_pred.projection_input
            predictions_sequence = tf.reshape(eval_logits, [batch_size, -1])
            # Previously never assigned in this branch, so the logging and
            # return below raised UnboundLocalError for timeseries. Mirror
            # the greedy (beam_width == 1) branch and use the raw decoder
            # outputs as scores.
            predictions_sequence_scores = final_outputs_pred.rnn_output
            predictions_sequence_length_with_eos = final_sequence_lengths_pred

        else:
            train_helper = tf.contrib.seq2seq.TrainingHelper(
                inputs=targets_embedded,
                sequence_length=targets_sequence_length_with_eos)
            final_outputs_train, final_state_train, final_sequence_lengths_train = decode(
                initial_state,
                cell,
                train_helper,
                projection_layer=projection_layer)
            eval_logits = final_outputs_train.rnn_output
            train_logits = final_outputs_train.projection_input
            # train_predictions = final_outputs_train.sample_id

            pred_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                embedding=targets_embeddings,
                start_tokens=start_tokens,
                end_token=END_SYMBOL)
            final_outputs_pred, final_state_pred, final_sequence_lengths_pred = decode(
                initial_state,
                cell,
                pred_helper,
                beam_width,
                projection_layer=projection_layer)

            if beam_width > 1:
                # Keep only the top beam (index 0 on the last axis).
                predictions_sequence = final_outputs_pred.beam_search_decoder_output.predicted_ids[:, :,
                                                                                                   0]
                # final_outputs_pred..predicted_ids[:,:,0] would work too, but it contains -1s for padding
                predictions_sequence_scores = final_outputs_pred.beam_search_decoder_output.scores[:, :,
                                                                                                   0]
                predictions_sequence_length_with_eos = final_sequence_lengths_pred[:,
                                                                                   0]
            else:
                predictions_sequence = final_outputs_pred.sample_id
                predictions_sequence_scores = final_outputs_pred.rnn_output
                predictions_sequence_length_with_eos = final_sequence_lengths_pred

    logging.debug('  train_logits: {0}'.format(train_logits))
    logging.debug('  eval_logits: {0}'.format(eval_logits))
    logging.debug('  predictions_sequence: {0}'.format(predictions_sequence))
    logging.debug('  predictions_sequence_scores: {0}'.format(
        predictions_sequence_scores))

    return predictions_sequence, predictions_sequence_scores, predictions_sequence_length_with_eos, \
           targets_sequence_length_with_eos, eval_logits, train_logits, class_weights, class_biases