Example #1
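
Both examples assume the usual malaya-speech training preamble. A minimal sketch; the module paths, vocabulary, and configuration values below are assumptions, not the project's exact API:

import tensorflow as tf
import malaya_speech.train as train
from malaya_speech.train.model import tacotron2

# Hypothetical character vocabulary; import the real one from your text frontend.
MALAYA_SPEECH_SYMBOLS = ['pad', 'EOS'] + list("abcdefghijklmnopqrstuvwxyz '-.,?!")

# Hypothetical module-level config consumed by optimize_loss inside model_fn.
parameters = {
    'optimizer_params': {},
    'max_grad_norm': 1.0,
}

def learning_rate_scheduler(global_step):
    # Hypothetical constant schedule; the real scripts use a decay function.
    return 1e-3
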
def model_fn(features, labels, mode, params):
    input_ids = features['text_ids']
    input_lengths = features['len_text_ids'][:, 0]
    speaker_ids = tf.constant([0], dtype=tf.int32)  # single speaker; not consumed below
    mel_outputs = features['mel']
    mel_lengths = features['len_mel'][:, 0]
    guided = features['g']
    stop_token_target = features['stop_token_target']
    batch_size = tf.shape(guided)[0]

    model = tacotron2.Model(
        [input_ids, input_lengths],
        [mel_outputs, mel_lengths],
        len(MALAYA_SPEECH_SYMBOLS),
    )
    r = model.decoder_logits['outputs']
    decoder_output, post_mel_outputs, alignment_histories, _, _, _ = r
    stop_token_predictions = model.decoder_logits['stop_token_prediction']

    stop_token = tf.expand_dims(stop_token_target, -1)
    max_length = tf.cast(tf.shape(decoder_output)[1], tf.int32)

    loss_f = tf.losses.mean_squared_error
    mask = tf.sequence_mask(lengths=mel_lengths,
                            maxlen=max_length,
                            dtype=tf.float32)
    mask = tf.expand_dims(mask, axis=-1)

    mel_loss_before = loss_f(labels=mel_outputs,
                             predictions=decoder_output,
                             weights=mask)
    mel_loss_after = loss_f(labels=mel_outputs,
                            predictions=post_mel_outputs,
                            weights=mask)
    stop_token_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=stop_token, logits=stop_token_predictions)
    stop_token_loss = stop_token_loss * mask
    stop_token_loss = tf.reduce_sum(stop_token_loss) / tf.reduce_sum(mask)

    # Guided attention loss (Tachibana et al., 2018): entries of -1 in `g`
    # mark padding; elsewhere `g` penalises attention mass far from the diagonal.
    attention_masks = tf.cast(tf.math.not_equal(guided, -1.0), tf.float32)
    loss_att = tf.reduce_sum(tf.abs(alignment_histories * guided) *
                             attention_masks,
                             axis=[1, 2])
    loss_att /= tf.reduce_sum(attention_masks, axis=[1, 2])
    loss_att = tf.reduce_mean(loss_att)

    loss = stop_token_loss + mel_loss_before + mel_loss_after + loss_att

    tf.identity(loss, name='loss')
    tf.identity(stop_token_loss, name='stop_token_loss')
    tf.identity(mel_loss_before, name='mel_loss_before')
    tf.identity(mel_loss_after, name='mel_loss_after')
    tf.identity(loss_att, name='loss_att')

    tf.summary.scalar('stop_token_loss', stop_token_loss)
    tf.summary.scalar('mel_loss_before', mel_loss_before)
    tf.summary.scalar('mel_loss_after', mel_loss_after)
    tf.summary.scalar('loss_att', loss_att)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # `parameters` and `learning_rate_scheduler` are module-level globals
        # (see the preamble sketch above); the Estimator `params` argument is unused.
        train_op = train.optimizer.optimize_loss(
            loss,
            tf.train.AdamOptimizer,
            parameters['optimizer_params'],
            learning_rate_scheduler,
            summaries=['learning_rate'],
            larc_params=parameters.get('larc_params', None),
            loss_scaling=parameters.get('loss_scaling', 1.0),
            loss_scaling_params=parameters.get('loss_scaling_params', None),
            clip_gradients=parameters.get('max_grad_norm', None),
        )
        estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                    loss=loss,
                                                    train_op=train_op)

    elif mode == tf.estimator.ModeKeys.EVAL:
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL, loss=loss)

    return estimator_spec
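
Either model_fn plugs into the standard tf.estimator loop. A minimal sketch, where get_dataset and the model_dir are hypothetical and the real scripts may use a project-specific training helper instead:

def input_fn():
    # Hypothetical: returns a tf.data.Dataset yielding the feature dict above.
    return get_dataset()

estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir='tacotron2-out')
estimator.train(input_fn=input_fn, max_steps=100000)
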
Example #2

The same model_fn with two changes: the masked MSE losses are replaced by Keras MAE and binary cross-entropy applied through the calculate_3d_loss / calculate_2d_loss helpers, and the stop-token targets are derived from mel_lengths instead of being read from the features.

def model_fn(features, labels, mode, params):
    input_ids = features['text_ids']
    input_lengths = features['len_text_ids'][:, 0]
    speaker_ids = tf.constant([0], dtype=tf.int32)
    mel_outputs = features['mel']
    mel_lengths = features['len_mel'][:, 0]
    guided = features['g']

    model = tacotron2.Model(
        [input_ids, input_lengths],
        [mel_outputs, mel_lengths],
        len(MALAYA_SPEECH_SYMBOLS),
    )

    r = model.decoder_logits['outputs']
    decoder_output, post_mel_outputs, alignment_histories, _, _, _ = r
    stop_token_predictions = model.decoder_logits['stop_token_prediction']
    stop_token_predictions = stop_token_predictions[:, :, 0]

    binary_crossentropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    mae = tf.keras.losses.MeanAbsoluteError()

    mel_loss_before = calculate_3d_loss(mel_outputs,
                                        decoder_output,
                                        loss_fn=mae)
    mel_loss_after = calculate_3d_loss(mel_outputs,
                                       post_mel_outputs,
                                       loss_fn=mae)
    # Build the stop-token targets on the fly: 0 for valid frames and 1 from
    # the end of each utterance onward, instead of reading them from features.
    max_mel_length = tf.reduce_max(mel_lengths)
    stop_gts = tf.expand_dims(
        tf.range(max_mel_length, dtype=tf.int32), 0)
    stop_gts = tf.tile(stop_gts, [tf.shape(mel_lengths)[0], 1])
    stop_gts = tf.cast(
        tf.math.greater_equal(stop_gts, tf.expand_dims(mel_lengths, 1)),
        tf.float32,
    )
    stop_token_loss = calculate_2d_loss(stop_gts,
                                        stop_token_predictions,
                                        loss_fn=binary_crossentropy)
    attention_masks = tf.cast(tf.math.not_equal(guided, -1.0), tf.float32)
    loss_att = tf.reduce_sum(tf.abs(alignment_histories * guided) *
                             attention_masks,
                             axis=[1, 2])
    loss_att /= tf.reduce_sum(attention_masks, axis=[1, 2])
    loss_att = tf.reduce_mean(loss_att)

    loss = stop_token_loss + mel_loss_before + mel_loss_after + loss_att

    tf.identity(loss, name='loss')
    tf.identity(stop_token_loss, name='stop_token_loss')
    tf.identity(mel_loss_before, name='mel_loss_before')
    tf.identity(mel_loss_after, name='mel_loss_after')
    tf.identity(loss_att, name='loss_att')

    tf.summary.scalar('stop_token_loss', stop_token_loss)
    tf.summary.scalar('mel_loss_before', mel_loss_before)
    tf.summary.scalar('mel_loss_after', mel_loss_after)
    tf.summary.scalar('loss_att', loss_att)

    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = train.optimizer.optimize_loss(
            loss,
            tf.train.AdamOptimizer,
            parameters['optimizer_params'],
            learning_rate_scheduler,
            summaries=[
                'learning_rate',
                'variables',
                'gradients',
                'larc_summaries',
                'variable_norm',
                'gradient_norm',
                'global_gradient_norm',
            ],
            larc_params=parameters.get('larc_params', None),
            loss_scaling=parameters.get('loss_scaling', 1.0),
            loss_scaling_params=parameters.get('loss_scaling_params', None),
            clip_gradients=parameters.get('max_grad_norm', None),
        )
        estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                    loss=loss,
                                                    train_op=train_op)

    elif mode == tf.estimator.ModeKeys.EVAL:
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL, loss=loss)

    return estimator_spec
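
calculate_3d_loss and calculate_2d_loss are not defined in this snippet. In the TensorFlowTTS style these examples follow, they trim prediction and target to a common time length (the decoder can over- or under-generate frames, e.g. under multi-GPU padding) before applying the loss. A sketch under that assumption:

def calculate_3d_loss(y_gt, y_pred, loss_fn):
    # Trim both [batch, time, dim] tensors to the shorter time axis.
    min_t = tf.minimum(tf.shape(y_gt)[1], tf.shape(y_pred)[1])
    return loss_fn(y_gt[:, :min_t, :], y_pred[:, :min_t, :])

def calculate_2d_loss(y_gt, y_pred, loss_fn):
    # Same idea for [batch, time] tensors such as stop-token tracks.
    min_t = tf.minimum(tf.shape(y_gt)[1], tf.shape(y_pred)[1])
    return loss_fn(y_gt[:, :min_t], y_pred[:, :min_t])

The `g` feature is the guided-attention penalty of Tachibana et al. (2018), W[n, t] = 1 - exp(-(n/N - t/T)^2 / (2 g^2)), with -1 in padded cells. A NumPy sketch; that the axis order matches alignment_histories is an assumption here:

import numpy as np

def guided_attention(text_len, mel_len, max_text, max_mel, g=0.2):
    # -1 marks padding so the attention mask above can exclude those cells.
    W = np.full((max_text, max_mel), -1.0, dtype=np.float32)
    for n in range(text_len):
        for t in range(mel_len):
            W[n, t] = 1.0 - np.exp(
                -((n / text_len - t / mel_len) ** 2) / (2 * g * g))
    return W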