Example #1
def compile_train_fn(model, learning_rate=2e-4):
    """ Build the CTC training routine for speech models.
    Args:
        model: A keras model (built=True) instance
    Returns:
        train_fn (theano.function): Function that takes in acoustic inputs,
            and updates the model. Returns network outputs and ctc cost
    """
    logger.info("Building train_fn")
    acoustic_input = model.inputs[0]
    network_output = model.outputs[0]
    output_lens = K.placeholder(ndim=1, dtype='int32')
    label = K.placeholder(ndim=1, dtype='int32')
    label_lens = K.placeholder(ndim=1, dtype='int32')
    # warp-CTC expects time-major activations: (max_time, batch, num_classes).
    network_output = network_output.dimshuffle((1, 0, 2))

    ctc_cost = ctc.cpu_ctc_th(network_output, output_lens, label,
                              label_lens).mean()
    trainable_vars = model.trainable_weights
    optimizer = SGD(nesterov=True,
                    lr=learning_rate,
                    momentum=0.9,
                    clipnorm=100)
    updates = optimizer.get_updates(trainable_vars, [], ctc_cost)
    train_fn = K.function(
        [acoustic_input, output_lens, label, label_lens,
         K.learning_phase()], [network_output, ctc_cost],
        updates=updates)
    return train_fn
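
A minimal usage sketch for compile_train_fn above; the model object, feature
dimension, and label ranges are illustrative assumptions, not taken from the
original source.

import numpy as np

# Assumed: `model` is a built Keras speech model whose output length matches
# the padded acoustic input length (all shapes below are illustrative).
train_fn = compile_train_fn(model, learning_rate=2e-4)

batch_size, max_time, feat_dim = 4, 200, 161
acoustic_batch = np.random.rand(batch_size, max_time, feat_dim).astype('float32')
output_lens = np.full(batch_size, max_time, dtype='int32')        # network output length per utterance
labels_flat = np.random.randint(1, 29, size=40).astype('int32')   # concatenated label sequences
label_lens = np.full(batch_size, 10, dtype='int32')               # 4 utterances x 10 labels = 40

# The final input, learning_phase=1, enables training-time behaviour such as dropout.
net_out, cost = train_fn([acoustic_batch, output_lens, labels_flat, label_lens, 1])
print('CTC cost for this batch:', float(cost))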
Example #2
    def make_train_fn(self, model, learning_rate, momentum, optimizer,
                      regmask_start_val, regmask_anneal_episodes):
        # The regularisation-mask arguments must be supplied together or not at all.
        assert (regmask_start_val is None) == (regmask_anneal_episodes is None)

        if optimizer == 'sgd':
            opt = SGD(lr=learning_rate, momentum=momentum)
        elif optimizer == 'adam':
            # Note: the Adam branch uses Keras' default hyper-parameters and
            # ignores the `learning_rate` argument.
            opt = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
        else:
            raise ValueError('Unrecognized optimizer: {}'.format(optimizer))
        state_input = model.input
        action_input = Input(batch_shape=(None, ),
                             dtype=np.int32,
                             name='action_input')
        reward_input = Input(batch_shape=(None, ),
                             dtype=np.float32,
                             name='reward_input')
        episode = tf.placeholder(dtype=tf.int32, shape=(), name='episode')

        action_probas = model(state_input)
        # Select the probability assigned to the action actually taken in each state.
        indexer = tf.stack(
            [tf.range(0, tf.shape(action_input)[0]), action_input], axis=1)
        selected_action_probas = tf.gather_nd(params=action_probas,
                                              indices=indexer)

        # Mean policy entropy over the batch; 1e-6 guards against log(0).
        mean_entropy = K.mean(
            -K.sum(action_probas * K.log(action_probas + 1e-6), axis=1))
        # REINFORCE objective: log-probability of each taken action weighted by its reward.
        objectives = reward_input * K.log(selected_action_probas)
        # The entropy bonus decays exponentially with the episode count.
        entropy_loss_decay_coefficient = K.exp(
            -self.entropy_boost_decay * tf.cast(episode, dtype=tf.float32))
        undiscounted_entropy_loss = -K.log(mean_entropy + self.EPSILON)
        entropy_loss = (entropy_loss_decay_coefficient * self.entropy_boost *
                        undiscounted_entropy_loss)

        regmask_loss = _get_regmask_loss(regmask_start_val,
                                         regmask_anneal_episodes,
                                         self.state_shape, self.action_space.n,
                                         model.layers[0].kernel, episode)

        # Total loss: negated policy-gradient objective plus entropy and regmask penalties.
        loss = (-K.sum(objectives) + entropy_loss + regmask_loss)
        updates = opt.get_updates(loss=loss, params=model.trainable_weights)

        # --- Statistics for monitoring purposes only
        grads_overall = opt.get_gradients(loss=loss,
                                          params=model.trainable_weights)
        grads_entropy = opt.get_gradients(loss=entropy_loss,
                                          params=model.trainable_weights)

        # We assume a linear model with weight-matrix and bias-vector
        assert len(grads_overall) == len(grads_entropy) == 2
        W_grad_length_overall = K.sqrt(K.sum(K.square(grads_overall[0])))
        W_grad_length_entropy = K.sqrt(K.sum(K.square(grads_entropy[0])))

        # weight_variance = K.var(model.trainable_weights[0])
        # -------------------------------

        train_fn = K.function(
            inputs=[state_input, action_input, reward_input, episode],
            outputs=[
                entropy_loss, undiscounted_entropy_loss, W_grad_length_overall,
                W_grad_length_entropy, regmask_loss
            ],
            updates=updates)
        return train_fn
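
A hedged sketch of driving the returned train_fn from a REINFORCE-style loop;
the agent, environment, and the run_episode/discount helpers are hypothetical
names introduced here for illustration only.

import numpy as np

# Assumed: `agent` exposes make_train_fn and a built policy `model`; `env` is a
# Gym-like environment; run_episode and discount are hypothetical helpers.
train_fn = agent.make_train_fn(agent.model,
                               learning_rate=1e-3,
                               momentum=0.9,
                               optimizer='sgd',
                               regmask_start_val=None,
                               regmask_anneal_episodes=None)

for episode in range(1000):
    states, actions, rewards = run_episode(env, agent)   # roll out one episode
    returns = discount(rewards, gamma=0.99)              # per-step discounted returns
    # Normalising returns is a common variance-reduction step for policy gradients.
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)

    (entropy_loss, raw_entropy_loss, grad_norm,
     entropy_grad_norm, regmask_loss) = train_fn(
        [np.asarray(states, dtype='float32'),
         np.asarray(actions, dtype='int32'),
         np.asarray(returns, dtype='float32'),
         episode])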