Example #1
    def compute_loss(self, targets, logits, logit_seq_length,
                     target_seq_length):
        '''
        Compute the loss

        Creates the operation to compute the cross-entropy loss for every input
        frame (if you want to use a different loss function, override this
        method)

        Args:
            targets: a [batch_size, max_target_length] tensor containing the
                targets
            logits: a [batch_size, max_logit_length, dim] tensor containing the
                logits
            logit_seq_length: the length of all the logit sequences as a
                [batch_size] vector
            target_seq_length: the length of all the target sequences as a
                [batch_size] vector

        Returns:
            a scalar value containing the loss
        '''

        with tf.name_scope('cross_entropy_loss'):
            output_dim = int(logits.get_shape()[2])

            #put all the targets on top of each other
            split_targets = tf.unstack(targets)
            for i, target in enumerate(split_targets):
                #only use the real data
                split_targets[i] = target[:target_seq_length[i]]

                #append an end of sequence label
                split_targets[i] = tf.concat(
                    [split_targets[i], [output_dim - 1]], 0)

            #concatenate the targets
            nonseq_targets = tf.concat(split_targets, 0)

            #convert the logits to non sequential data
            nonseq_logits = ops.seq2nonseq(logits, logit_seq_length)

            #one hot encode the targets
            #pylint: disable=E1101
            nonseq_targets = tf.one_hot(nonseq_targets, output_dim)
            #disabled: collect the attention tensor, with shape
            #[batch_size, sequence_length, output_dim]
            #attention = tf.get_collection('attention')
            #attention = attention[0]

            #compute the cross-entropy loss
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=nonseq_logits,
                                                        labels=nonseq_targets))
            #+ 0.001*tf.reduce_mean(attention_normalize(attention))
            #+ 0.05*tf.reduce_mean(attention_prior(attention))

        return loss
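
Note that this loss implicitly assumes that every logit sequence is exactly one frame longer than its target sequence (the extra frame accounts for the appended end-of-sequence label), since the flattened logits and targets must line up one-to-one. The same trim/append/flatten pattern can be written without the project's ops module; the sketch below is a minimal, assumption-laden rewrite in which tf.sequence_mask and tf.boolean_mask stand in for ops.seq2nonseq.

import tensorflow as tf

def flat_cross_entropy(targets, logits, logit_seq_length,
                       target_seq_length, output_dim):
    '''sketch of the loss above; tf.boolean_mask replaces ops.seq2nonseq
    (an assumption about its behaviour, not the nabu implementation)'''

    #trim every target to its true length and append the EOS label
    split_targets = tf.unstack(targets)
    for i, target in enumerate(split_targets):
        split_targets[i] = tf.concat(
            [target[:target_seq_length[i]], [output_dim - 1]], 0)
    nonseq_targets = tf.concat(split_targets, 0)

    #drop the padded frames from the logits
    mask = tf.sequence_mask(logit_seq_length, tf.shape(logits)[1])
    nonseq_logits = tf.boolean_mask(logits, mask)

    #average the frame-level cross-entropy
    return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
        logits=nonseq_logits,
        labels=tf.one_hot(nonseq_targets, output_dim)))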
Example #2
    def compute_loss(self, targets, logits, logit_seq_length,
                     target_seq_length):
        '''
        Compute the loss

        Creates the operation to compute the CTC loss for every input
        frame (if you want to use a different loss function, override this
        method)

        Args:
            targets: a tuple of targets, the first one being a
                [batch_size, max_target_length] tensor containing the real
                targets, the second one being a [batch_size, max_audioseq_length]
                tensor containing the audio samples or other extra information.
            logits: a [batch_size, max_logit_length, dim] tensor containing the
                logits
            logit_seq_length: the length of all the logit sequences as a
                [batch_size] vector
            target_seq_length: the lengths of all the target sequences as a
                tuple of two [batch_size] vectors, one for each element in
                the targets tuple

        Returns:
            a scalar value containing the loss
        '''

        with tf.name_scope('CTC_loss'):

            #take the real targets and add a feature dimension
            targets = tf.expand_dims(targets[0], 2)

            #get the batch size
            batch_size = int(targets.get_shape()[0])

            #convert the targets into a sparse tensor representation
            indices = tf.concat([
                tf.concat([
                    tf.expand_dims(tf.tile([s], [target_seq_length[s]]), 1),
                    tf.expand_dims(tf.range(target_seq_length[s]), 1)
                ], 1) for s in range(batch_size)
            ], 0)

            values = tf.reshape(ops.seq2nonseq(targets, target_seq_length),
                                [-1])

            shape = [batch_size, int(targets.get_shape()[1])]

            sparse_targets = tf.SparseTensor(tf.cast(indices, tf.int64),
                                             values, shape)

            loss = tf.reduce_mean(
                tf.nn.ctc_loss(sparse_targets,
                               logits,
                               logit_seq_length,
                               time_major=False))

        return loss
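
The list comprehension above builds the sparse indices one utterance at a time; the same dense-to-sparse conversion can also be written with a mask, which avoids the Python loop over the batch. Below is a minimal sketch using standard TensorFlow 1.x ops (dense_to_sparse is a hypothetical helper name, not part of nabu).

import tensorflow as tf

def dense_to_sparse(targets, target_seq_length):
    '''convert padded [batch_size, max_target_length] targets into the
    tf.SparseTensor that tf.nn.ctc_loss expects (a sketch)'''

    #boolean mask marking the non-padded positions
    mask = tf.sequence_mask(target_seq_length, tf.shape(targets)[1])

    #[N, 2] int64 indices of the non-padded entries
    indices = tf.where(mask)

    #the label values at those positions
    values = tf.gather_nd(targets, indices)

    shape = tf.shape(targets, out_type=tf.int64)

    return tf.SparseTensor(indices, values, shape)

#usage: loss = tf.reduce_mean(tf.nn.ctc_loss(
#    dense_to_sparse(targets, target_seq_length),
#    logits, logit_seq_length, time_major=False))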
Example #3
File: layer.py Project: r0371984/nabu
    def __call__(self,
                 inputs,
                 seq_length,
                 causal=False,
                 is_training=False,
                 scope=None):
        '''
        Create the variables and do the forward computation

        Args:
            inputs: the input to the layer as a
                [batch_size, max_length, dim] tensor
            seq_length: the length of the input sequences
            causal: flag for causality; if True, every output will only be
                affected by previous inputs
            is_training: whether or not the network is in training mode
            scope: The variable scope sets the namespace under which
                the variables created during this call will be stored.

        Returns:
            the output, a [batch_size, max_length/stride, num_units] tensor
        '''

        with tf.variable_scope(scope or type(self).__name__):

            input_dim = int(inputs.get_shape()[2])
            stddev = 1 / input_dim**0.5

            #the filter parameters
            w = tf.get_variable(
                'filter', [self.kernel_size, input_dim, self.num_units],
                initializer=tf.random_normal_initializer(stddev=stddev))

            #the bias parameters
            b = tf.get_variable(
                'bias', [self.num_units],
                initializer=tf.random_normal_initializer(stddev=stddev))

            #do the atrous convolution
            if causal:
                out = ops.causal_aconv1d(inputs, w, self.dilation_rate)
            else:
                out = ops.aconv1d(inputs, w, self.dilation_rate)

            #add the bias, applying it only to the non-padded frames
            out = ops.seq2nonseq(out, seq_length)
            out += b
            out = ops.nonseq2seq(out, seq_length, int(inputs.get_shape()[1]))

        return out
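
For reference, the dilated (atrous) convolutions that ops.aconv1d and ops.causal_aconv1d provide can be approximated with plain TensorFlow; the sketch below is an assumption about what those helpers compute, not the nabu implementation. The causal variant pads only on the left, so the output at time t never sees inputs later than t.

import tensorflow as tf

def aconv1d_sketch(inputs, w, rate, causal=False):
    '''dilated 1-D convolution over [batch_size, max_length, dim] inputs
    with a [kernel_size, input_dim, num_units] filter (a sketch of what
    ops.aconv1d and ops.causal_aconv1d presumably do)'''

    kernel_size = int(w.get_shape()[0])

    if causal:
        #left-pad so every output depends only on current and past inputs
        pad = (kernel_size - 1) * rate
        inputs = tf.pad(inputs, [[0, 0], [pad, 0], [0, 0]])
        return tf.nn.convolution(inputs, w, padding='VALID',
                                 dilation_rate=[rate])

    #non-causal: symmetric zero padding keeps the output length unchanged
    return tf.nn.convolution(inputs, w, padding='SAME',
                             dilation_rate=[rate])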