示例#1
0
def compile_train_fn(model, learning_rate=2e-4):
    """ Build the CTC training routine for speech models.
    Args:
        model: A keras model (built=True) instance
        learning_rate: Step size for the Nesterov-momentum updates
    Returns:
        train_fn (theano.function): Function that takes in acoustic inputs,
            and updates the model. Returns network outputs and ctc cost
    """
    logger.info("Building train_fn")
    acoustic_input = model.inputs[0]
    network_output = model.outputs[0]
    output_lens = K.placeholder(ndim=1, dtype='int32')
    label = K.placeholder(ndim=1, dtype='int32')
    label_lens = K.placeholder(ndim=1, dtype='int32')
    # warp-ctc expects time-major activations: (time, batch, classes).
    network_output = network_output.dimshuffle((1, 0, 2))

    ctc_cost = ctc.cpu_ctc_th(network_output, output_lens, label,
                              label_lens).mean()
    trainable_vars = model.trainable_weights
    grads = K.gradients(ctc_cost, trainable_vars)
    # Clip the total gradient norm at 100 to stabilise training.
    grads = lasagne.updates.total_norm_constraint(grads, 100)
    updates = lasagne.updates.nesterov_momentum(grads, trainable_vars,
                                                learning_rate, 0.99)
    train_fn = K.function(
        [acoustic_input, output_lens, label, label_lens,
         K.learning_phase()], [network_output, ctc_cost],
        updates=updates)
    return train_fn
示例#2
0
 def __init__(self, model):
     # Flattened target labels and per-sample label lengths, as the int32
     # vectors required by the warp-ctc bindings.
     self._yhat = T.vector(name='yhat', dtype='int32')
     self._ylen = T.vector(name='ylen', dtype='int32')
     self._model = model
     # Summed CTC loss over the whole batch (warp-ctc CPU implementation).
     self._loss = ctc.cpu_ctc_th(model.output, model.output_size,
                                 self._yhat, self._ylen).sum()
     # NOTE(review): self.input / self.output are not assigned in this
     # constructor -- presumably defined elsewhere on the class; verify.
     self._callf = theano.function(inputs=self.input, outputs=self.output)
示例#3
0
文件: simple.py 项目: noammor/ctc
def create_theano_func():
    """Compile cost and gradient functions around cpu_ctc_th.

    Returns:
        (f, g): theano functions over (acts, act_lens, labels, label_lens);
        f yields the mean CTC cost of the batch and g its gradient with
        respect to the activations. Inputs are downcast as needed.
    """
    activations = T.ftensor3()
    activation_lengths = T.ivector()
    flat_labels = T.ivector()
    label_lengths = T.ivector()
    mean_cost = T.mean(cpu_ctc_th(activations, activation_lengths,
                                  flat_labels, label_lengths))
    cost_gradient = T.grad(mean_cost, activations)
    inputs = [activations, activation_lengths, flat_labels, label_lengths]
    f = theano.function(inputs, mean_cost, allow_input_downcast=True)
    g = theano.function(inputs, cost_gradient, allow_input_downcast=True)
    return f, g
示例#4
0
def create_theano_func():
    """Build the (cost, gradient) pair of theano functions for warp-CTC.

    Both returned functions take (acts, act_lens, labels, label_lens);
    the first computes the batch-mean CTC cost, the second computes
    d(cost)/d(acts). allow_input_downcast lets callers pass plain
    Python/numpy values.
    """
    acts = T.ftensor3()
    int_vectors = [T.ivector() for _ in range(3)]
    act_lens, labels, label_lens = int_vectors
    per_sequence_costs = cpu_ctc_th(acts, act_lens, labels, label_lens)
    cost = per_sequence_costs.mean()
    acts_grad = T.grad(cost, acts)
    arguments = [acts] + int_vectors
    cost_fn = theano.function(arguments, cost, allow_input_downcast=True)
    grad_fn = theano.function(arguments, acts_grad,
                              allow_input_downcast=True)
    return cost_fn, grad_fn
示例#5
0
def compile_test_fn(model):
    """ Build a testing routine for speech models.
    Args:
        model: A keras model (built=True) instance
    Returns:
        val_fn (theano.function): Function that takes in acoustic inputs,
            and calculates the loss. Returns network outputs and ctc cost
    """
    logger.info("Building val_fn")
    acoustics = model.inputs[0]
    predictions = model.outputs[0]
    prediction_lens = K.placeholder(ndim=1, dtype='int32')
    labels = K.placeholder(ndim=1, dtype='int32')
    label_lens = K.placeholder(ndim=1, dtype='int32')
    # warp-ctc wants time-major activations: (time, batch, classes).
    predictions = predictions.dimshuffle((1, 0, 2))

    loss = ctc.cpu_ctc_th(predictions, prediction_lens,
                          labels, label_lens).mean()
    val_fn = K.function(
        [acoustics, prediction_lens, labels, label_lens,
         K.learning_phase()],
        [predictions, loss])
    return val_fn
示例#6
0
    h1,
    n_hidden,
    grad_clipping=grad_clip,
    nonlinearity=lasagne.nonlinearities.rectify,
    backwards=True)
h2 = lasagne.layers.ElemwiseSumLayer([h2f, h2b])

h3 = lasagne.layers.RecurrentLayer(h2,
                                   num_classes,
                                   grad_clipping=grad_clip,
                                   nonlinearity=lasagne.nonlinearities.linear)
l_out = lasagne.layers.ReshapeLayer(h3, ((max_len, mbsz, num_classes)))

network_output = lasagne.layers.get_output(l_out)

cost = T.mean(ctc.cpu_ctc_th(network_output, input_lens, output, output_lens))
grads = T.grad(cost, wrt=network_output)
all_params = lasagne.layers.get_all_params(l_out)
updates = lasagne.updates.adam(cost, all_params, 0.001)

train = theano.function([l_in.input_var, input_lens, output, output_lens],
                        cost,
                        updates=updates)
predict = theano.function([l_in.input_var], network_output)
get_grad = theano.function([l_in.input_var, input_lens, output, output_lens],
                           grads)

from loader import DataLoader

data_loader = DataLoader(mbsz=mbsz,
                         min_len=min_len,
def train(step_rule, label_dim, state_dim, epochs,
          seed, dropout, test_cost, experiment_path, features, weight_noise,
          to_watch, patience, batch_size, batch_norm, **kwargs):
    """Train a CTC speech model on TIMIT (Python 2, Blocks/Theano).

    Args:
        step_rule: Blocks step rule handed to GradientDescent.
        label_dim: number of target labels; one extra output class is
            added (presumably the CTC blank -- TODO confirm).
        state_dim: hidden dimensionality of the recurrent layer.
        epochs: training stops after this many epochs (FinishAfter).
        seed: RNG seed for the data streams.
        dropout: NOTE(review): accepted but never used in this body.
        test_cost: if True, also attach monitors on the test stream.
        experiment_path: NOTE(review): only referenced by the
            commented-out checkpointing code below.
        features: TIMIT feature variant passed to the Timit datasets.
        weight_noise: NOTE(review): accepted but never used in this body.
        to_watch: name of the quantity EarlyStopping tracks.
        patience: EarlyStopping patience.
        batch_size: minibatch size for all streams.
        batch_norm: use LSTMBatchNorm instead of plain LSTM when True.
        **kwargs: optional 'load_path' to restore a saved main loop.
    """

    print '.. TIMIT experiment'
    print '.. arguments:', ' '.join(sys.argv)
    t0 = time.time()


    # ------------------------------------------------------------------------
    # Streams

    rng = np.random.RandomState(seed)
    stream_args = dict(rng=rng, batch_size=batch_size)

    print '.. initializing iterators'
    train_dataset = Timit('train', features=features)
    train_stream = construct_stream(train_dataset, **stream_args)
    dev_dataset = Timit('dev', features=features)
    dev_stream = construct_stream(dev_dataset, **stream_args)
    test_dataset = Timit('test', features=features)
    test_stream = construct_stream(test_dataset, **stream_args)
    # NOTE(review): update_stream is built but never used below.
    update_stream = construct_stream(train_dataset, n_batches=100,
                                     **stream_args)

    # Collapse phones to phonemes where a mapping exists; keep as-is
    # otherwise.
    phone_dict = train_dataset.get_phoneme_dict()
    phoneme_dict = {k: phone_to_phoneme_dict[v]
                    if v in phone_to_phoneme_dict else v
                    for k, v in phone_dict.iteritems()}
    ind_to_phoneme = {v: k for k, v in phoneme_dict.iteritems()}
    eol_symbol = ind_to_phoneme['<STOP>']


    # ------------------------------------------------------------------------
    # Graph

    print '.. building model'
    x = T.tensor3('features')
    y = T.matrix('phonemes')
    input_mask = T.matrix('features_mask')
    output_mask = T.matrix('phonemes_mask')

    # Test values are defined but compute_test_value is switched off.
    theano.config.compute_test_value = 'off'
    x.tag.test_value = np.random.randn(100, 24, 123).astype(floatX)
    y.tag.test_value = np.ones((30, 24), dtype=floatX)
    input_mask.tag.test_value = np.ones((100, 24), dtype=floatX)
    output_mask.tag.test_value = np.ones((30, 24), dtype=floatX)

    # NOTE(review): seq_len and recurrent_init only feed the commented-out
    # SimpleRecurrent variants below.
    seq_len = 100
    input_dim = 123
    activation = Tanh()
    recurrent_init = IdentityInit(0.99)

    if batch_norm:
        rec1 = LSTMBatchNorm(name='rec1',
                             dim=state_dim,
                             activation=activation,
                             weights_init=NormalizedInitialization())
        #rec1 = SimpleRecurrentBatchNorm(name='rec1',
        #                                dim=state_dim,
        #                                activation=activation,
        #                                seq_len=seq_len,
        #                                weights_init=recurrent_init)
        #rec2 = SimpleRecurrentBatchNorm(name='rec2',
        #                                dim=state_dim,
        #                                activation=activation,
        #                                seq_len=seq_len,
        #                                weights_init=recurrent_init)
        #rec3 = SimpleRecurrentBatchNorm(name='rec3',
        #                                dim=state_dim,
        #                                activation=activation,
        #                                seq_len=seq_len,
        #                                weights_init=recurrent_init)
    else:
        rec1 = LSTM(name='rec1', dim=state_dim, activation=activation,
                    weights_init=NormalizedInitialization())
        #rec1 = SimpleRecurrent(name='rec1', dim=state_dim, activation=activation,
        #                       weights_init=recurrent_init)
        #rec2 = SimpleRecurrent(name='rec2', dim=state_dim, activation=activation,
        #                       weights_init=recurrent_init)
        #rec3 = SimpleRecurrent(name='rec3', dim=state_dim, activation=activation,
        #                       weights_init=recurrent_init)

    rec1.initialize()
    #rec2.initialize()
    #rec3.initialize()

    # input -> recurrent -> (label_dim + 1) unnormalized outputs.
    s1 = MyRecurrent(rec1, [input_dim, state_dim, label_dim + 1],
                     activations=[Identity(), Identity()], name='s1')
    #s2 = MyRecurrent(rec2, [state_dim, state_dim, state_dim],
    #                 activations=[Identity(), Identity()], name='s2')
    #s3 = MyRecurrent(rec3, [state_dim, state_dim, label_dim + 1],
    #                 activations=[Identity(), Identity()], name='s3')

    s1.initialize()
    #s2.initialize()
    #s3.initialize()

    o1 = s1.apply(x, input_mask)
    #o2 = s2.apply(o1)
    #y_hat_o = s3.apply(o2)
    y_hat_o = o1

    # Softmax over the last (class) axis via a flattened view.
    shape = y_hat_o.shape
    y_hat = Softmax().apply(y_hat_o.reshape((-1, shape[-1]))).reshape(shape)

    y_mask = output_mask
    y_hat_mask = input_mask


    # ------------------------------------------------------------------------
    # Costs and Algorithm

    # warp-ctc is fed unnormalized activations, sequence lengths derived
    # from the masks, and labels shifted by +1 (index 0 presumably
    # reserved for the CTC blank -- TODO confirm).
    ctc_cost = T.sum(ctc.cpu_ctc_th(
         y_hat_o, T.sum(y_hat_mask, axis=0),
         y + T.ones_like(y), T.sum(y_mask, axis=0)))
    batch_cost = ctc_cost.copy(name='batch_cost')

    bs = y.shape[1]
    cost_train = aggregation.mean(batch_cost, bs).copy("sequence_cost")
    cost_per_character = aggregation.mean(batch_cost,
                                          output_mask.sum()).copy(
                                                  "character_cost")
    cg_train = ComputationGraph(cost_train)

    model = Model(cost_train)
    train_cost_per_character = aggregation.mean(cost_train,
                                                output_mask.sum()).copy(
                                                        "train_character_cost")

    algorithm = GradientDescent(step_rule=step_rule, cost=cost_train,
                                parameters=cg_train.parameters,
                                on_unused_sources='warn')



    # ------------------------------------------------------------------------
    # Monitoring and extensions

    # Track parameter and gradient L2 norms alongside the costs.
    parameters = model.get_parameter_dict()
    observed_vars = [cost_train, train_cost_per_character,
                     aggregation.mean(algorithm.total_gradient_norm)]
    for name, param in parameters.iteritems():
        observed_vars.append(param.norm(2).copy(name + "_norm"))
        observed_vars.append(algorithm.gradients[param].norm(2).copy(name + "_grad_norm"))
    train_monitor = TrainingDataMonitoring(
        variables=observed_vars,
        prefix="train", after_epoch=True)

    dev_monitor = DataStreamMonitoring(
        variables=[cost_train, cost_per_character],
        data_stream=dev_stream, prefix="dev"
    )
    train_ctc_monitor = CTCMonitoring(x, input_mask, y_hat, eol_symbol, train_stream,
                                      prefix='train', every_n_epochs=1,
                                      before_training=True,
                                      phoneme_dict=phoneme_dict,
                                      black_list=black_list, train=True)
    dev_ctc_monitor = CTCMonitoring(x, input_mask, y_hat, eol_symbol, dev_stream,
                                    prefix='dev', every_n_epochs=1,
                                    phoneme_dict=phoneme_dict,
                                    black_list=black_list)

    extensions = []
    if 'load_path' in kwargs:
        extensions.append(Load(kwargs['load_path']))

    extensions.extend([FinishAfter(after_n_epochs=epochs),
                       train_monitor,
                       dev_monitor,
                       train_ctc_monitor,
                       dev_ctc_monitor])

    if test_cost:
        test_monitor = DataStreamMonitoring(
            variables=[cost_train, cost_per_character],
            data_stream=test_stream,
            prefix="test"
        )
        test_ctc_monitor = CTCMonitoring(x, input_mask, y_hat, eol_symbol, test_stream,
                                         prefix='test', every_n_epochs=1,
                                         phoneme_dict=phoneme_dict,
                                         black_list=black_list)
        extensions.append(test_monitor)
        extensions.append(test_ctc_monitor)

    #if not os.path.exists(experiment_path):
    #    os.makedirs(experiment_path)
    #best_path = os.path.join(experiment_path, 'best/')
    #if not os.path.exists(best_path):
    #    os.mkdir(best_path)
    #best_path = os.path.join(best_path, 'model.bin')
    # NOTE(review): the best-model path is '/dev/null', so the best model
    # is discarded; restore the commented-out path handling to keep it.
    extensions.append(EarlyStopping(to_watch, patience, '/dev/null'))
    extensions.extend([ProgressBar(), Printing()])


    # ------------------------------------------------------------------------
    # Main Loop

    main_loop = MainLoop(model=model, data_stream=train_stream,
                         algorithm=algorithm, extensions=extensions)

    print "Building time: %f" % (time.time() - t0)
   # if write_predictions:
   #     with open('predicted.txt', 'w') as f_pred:
   #         with open('targets.txt', 'w') as f_targets:
   #             evaluator = CTCEvaluator(
   #                 eol_symbol, x, input_mask, y_hat, phoneme_dict, black_list)
   #             evaluator.evaluate(dev_stream, file_pred=f_pred,
   #                                file_targets=f_targets)
   #     return
    main_loop.run()
示例#8
0
    l_in, n_hidden, grad_clipping=grad_clip, nonlinearity=lasagne.nonlinearities.rectify, backwards=True
)
h1 = lasagne.layers.ElemwiseSumLayer([h1f, h1b])

h2f = lasagne.layers.RecurrentLayer(h1, n_hidden, grad_clipping=grad_clip, nonlinearity=lasagne.nonlinearities.rectify)
h2b = lasagne.layers.RecurrentLayer(
    h1, n_hidden, grad_clipping=grad_clip, nonlinearity=lasagne.nonlinearities.rectify, backwards=True
)
h2 = lasagne.layers.ElemwiseSumLayer([h2f, h2b])

h3 = lasagne.layers.RecurrentLayer(h2, num_classes, grad_clipping=grad_clip, nonlinearity=lasagne.nonlinearities.linear)
l_out = lasagne.layers.ReshapeLayer(h3, ((max_len, mbsz, num_classes)))

network_output = lasagne.layers.get_output(l_out)

cost = T.mean(ctc.cpu_ctc_th(network_output, input_lens, output, output_lens))
grads = T.grad(cost, wrt=network_output)
all_params = lasagne.layers.get_all_params(l_out)
updates = lasagne.updates.adam(cost, all_params, 0.001)

train = theano.function([l_in.input_var, input_lens, output, output_lens], cost, updates=updates)
predict = theano.function([l_in.input_var], network_output)
get_grad = theano.function([l_in.input_var, input_lens, output, output_lens], grads)

from loader import DataLoader

data_loader = DataLoader(mbsz=mbsz, min_len=min_len, max_len=max_len, num_classes=num_classes)

i = 1
while True:
    i += 1
示例#9
0
文件: ctc.py 项目: zhangf911/kur
	def get_loss(self, model, target, output):
		""" Returns the loss function that can be used by the implementation-
			specific model.

			For the Keras backend, returns a tuple of
			((data source name, placeholder) pairs, loss tensor); any
			other backend raises ValueError.
		"""
		backend = model.get_backend()

		if backend.get_name() == 'keras':

			import keras.backend as K

			if self.variant is None:

				# Just use the built-in Keras CTC loss function.
				logger.debug('Attaching built-in Keras CTC loss function to '
					'model output "%s".', target)

			elif self.variant == 'warp':

				# Use the external warp-ctc bindings instead of Keras' CTC.
				logger.info('Attaching Warp-CTC loss function to model '
					'output "%s".', target)

				if backend.get_toolchain() != 'theano':
					logger.error('If you want to use warp-ctc, you need to '
						'use the Theano backend to Keras.')
					raise ValueError('Warp-CTC is currently only supported '
						'with the Theano backend to Keras.')

			else:
				raise ValueError('Unsupported variant "{}" on loss function '
					'"{}" for backend "{}".'.format(self.variant,
						self.get_name(), backend.get_name()))

			# Names of the auxiliary data sources registered below.
			ctc_scaled = 'ctc_scaled_{}'.format(self.input_length)
			flattened_labels = 'ctc_flattened_labels_{}'.format(target)

			transcript_length = K.placeholder(
				ndim=2,
				dtype='int32',
				name=self.output_length
			)
			transcript = K.placeholder(
				ndim=2,
				dtype='int32',
				name=self.output if self.variant is None \
					else flattened_labels
			)
			utterance_length = K.placeholder(
				ndim=2,
				dtype='int32',
				name=self.input_length if self.relative_to is None \
					else ctc_scaled
			)

			# Scale utterance lengths relative to another output if asked.
			if self.relative_to is not None:
				model.add_data_source(
					ctc_scaled,
					ScaledSource(
						model,
						relative_to=self.relative_to,
						to_this=target,
						scale_this=self.input_length
					)
				)

			# warp-ctc consumes the label batch flattened into one vector.
			if self.variant == 'warp':
				model.add_data_source(
					flattened_labels,
					FlattenSource(
						self.output,
						self.output_length
					)
				)

			if self.variant is None:
				out = K.ctc_batch_cost(
					transcript,
					output,
					utterance_length,
					transcript_length
				)
			else:
				import ctc						# pylint: disable=import-error
				# warp-ctc: time-major activations, squeezed length
				# vectors, and labels shifted by +1 (0 presumably reserved
				# for the CTC blank -- TODO confirm).
				out = ctc.cpu_ctc_th(
					output.dimshuffle((1, 0, 2)),
					K.squeeze(utterance_length, -1),
					transcript[0]+1,
					K.squeeze(transcript_length, -1)
				)

			return (
				(
					(self.output_length, transcript_length),
					(self.output if self.variant is None \
						else flattened_labels, transcript),
					(self.input_length if self.relative_to is None \
						else ctc_scaled, utterance_length)
				),
				out
			)

		else:
			raise ValueError('Unsupported backend "{}" for loss function "{}"'
				.format(backend.get_name(), self.get_name()))
def train(step_rule, label_dim, state_dim, epochs, seed, dropout, test_cost,
          experiment_path, features, weight_noise, to_watch, patience,
          batch_size, batch_norm, **kwargs):
    """Train a TimLSTM CTC model on TIMIT (Python 2, Blocks/Theano).

    Args:
        step_rule: Blocks step rule handed to GradientDescent.
        label_dim: number of target labels; the output Linear adds one
            extra class (presumably the CTC blank -- TODO confirm).
        state_dim: hidden dimensionality of the LSTM.
        epochs: training stops after this many epochs (FinishAfter).
        seed: RNG seed for the data streams.
        dropout: NOTE(review): accepted but never used in this body.
        test_cost: if True, also attach monitors on the test stream.
        experiment_path: NOTE(review): only referenced by the
            commented-out checkpointing code below.
        features: TIMIT feature variant passed to the Timit datasets.
        weight_noise: NOTE(review): accepted but never used in this body.
        to_watch: name of the quantity EarlyStopping tracks.
        patience: EarlyStopping patience.
        batch_size: minibatch size for all streams.
        batch_norm: disables TimLSTM's first flag when True (see call).
        **kwargs: optional 'load_path' to restore a saved main loop.
    """

    print '.. TIMIT experiment'
    print '.. arguments:', ' '.join(sys.argv)
    t0 = time.time()

    # ------------------------------------------------------------------------
    # Streams

    rng = np.random.RandomState(seed)
    stream_args = dict(rng=rng, batch_size=batch_size)

    print '.. initializing iterators'
    train_dataset = Timit('train', features=features)
    train_stream = construct_stream(train_dataset, **stream_args)
    dev_dataset = Timit('dev', features=features)
    dev_stream = construct_stream(dev_dataset, **stream_args)
    test_dataset = Timit('test', features=features)
    test_stream = construct_stream(test_dataset, **stream_args)
    # NOTE(review): update_stream is built but never used below.
    update_stream = construct_stream(train_dataset,
                                     n_batches=100,
                                     **stream_args)

    # Collapse phones to phonemes where a mapping exists.
    phone_dict = train_dataset.get_phoneme_dict()
    phoneme_dict = {
        k: phone_to_phoneme_dict[v] if v in phone_to_phoneme_dict else v
        for k, v in phone_dict.iteritems()
    }
    ind_to_phoneme = {v: k for k, v in phoneme_dict.iteritems()}
    eol_symbol = ind_to_phoneme['<STOP>']

    # ------------------------------------------------------------------------
    # Graph

    print '.. building model'
    x = T.tensor3('features')
    y = T.matrix('phonemes')
    input_mask = T.matrix('features_mask')
    output_mask = T.matrix('phonemes_mask')

    # Test values are defined but compute_test_value is switched off.
    theano.config.compute_test_value = 'off'
    x.tag.test_value = np.random.randn(100, 24, 123).astype(floatX)
    y.tag.test_value = np.ones((30, 24), dtype=floatX)
    input_mask.tag.test_value = np.ones((100, 24), dtype=floatX)
    output_mask.tag.test_value = np.ones((30, 24), dtype=floatX)

    # NOTE(review): seq_len and recurrent_init are assigned but unused.
    seq_len = 100
    input_dim = 123
    activation = Tanh()
    recurrent_init = IdentityInit(0.99)

    rec1 = TimLSTM(not batch_norm,
                   input_dim,
                   state_dim,
                   activation,
                   name='LSTM')
    rec1.initialize()
    l1 = Linear(state_dim,
                label_dim + 1,
                name='out_linear',
                weights_init=Orthogonal(),
                biases_init=Constant(0.0))
    l1.initialize()
    o1 = rec1.apply(x)
    y_hat_o = l1.apply(o1)

    # Softmax over the last (class) axis via a flattened view.
    shape = y_hat_o.shape
    y_hat = Softmax().apply(y_hat_o.reshape((-1, shape[-1]))).reshape(shape)

    y_mask = output_mask
    y_hat_mask = input_mask

    # ------------------------------------------------------------------------
    # Costs and Algorithm

    # warp-ctc is fed unnormalized activations, sequence lengths derived
    # from the masks, and labels shifted by +1 (index 0 presumably
    # reserved for the CTC blank -- TODO confirm).
    ctc_cost = T.sum(
        ctc.cpu_ctc_th(y_hat_o, T.sum(y_hat_mask, axis=0), y + T.ones_like(y),
                       T.sum(y_mask, axis=0)))
    batch_cost = ctc_cost.copy(name='batch_cost')

    bs = y.shape[1]
    cost_train = aggregation.mean(batch_cost, bs).copy("sequence_cost")
    cost_per_character = aggregation.mean(
        batch_cost, output_mask.sum()).copy("character_cost")
    cg_train = ComputationGraph(cost_train)

    model = Model(cost_train)
    train_cost_per_character = aggregation.mean(
        cost_train, output_mask.sum()).copy("train_character_cost")

    algorithm = GradientDescent(step_rule=step_rule,
                                cost=cost_train,
                                parameters=cg_train.parameters,
                                on_unused_sources='warn')

    # ------------------------------------------------------------------------
    # Monitoring and extensions

    # Track parameter and gradient L2 norms alongside the costs.
    parameters = model.get_parameter_dict()
    observed_vars = [
        cost_train, train_cost_per_character,
        aggregation.mean(algorithm.total_gradient_norm)
    ]
    for name, param in parameters.iteritems():
        observed_vars.append(param.norm(2).copy(name + "_norm"))
        observed_vars.append(
            algorithm.gradients[param].norm(2).copy(name + "_grad_norm"))
    train_monitor = TrainingDataMonitoring(variables=observed_vars,
                                           prefix="train",
                                           after_epoch=True)

    dev_monitor = DataStreamMonitoring(
        variables=[cost_train, cost_per_character],
        data_stream=dev_stream,
        prefix="dev")
    train_ctc_monitor = CTCMonitoring(x,
                                      input_mask,
                                      y_hat,
                                      eol_symbol,
                                      train_stream,
                                      prefix='train',
                                      every_n_epochs=1,
                                      before_training=True,
                                      phoneme_dict=phoneme_dict,
                                      black_list=black_list,
                                      train=True)
    dev_ctc_monitor = CTCMonitoring(x,
                                    input_mask,
                                    y_hat,
                                    eol_symbol,
                                    dev_stream,
                                    prefix='dev',
                                    every_n_epochs=1,
                                    phoneme_dict=phoneme_dict,
                                    black_list=black_list)

    extensions = []
    if 'load_path' in kwargs:
        extensions.append(Load(kwargs['load_path']))

    extensions.extend([
        FinishAfter(after_n_epochs=epochs), train_monitor, dev_monitor,
        train_ctc_monitor, dev_ctc_monitor
    ])

    if test_cost:
        test_monitor = DataStreamMonitoring(
            variables=[cost_train, cost_per_character],
            data_stream=test_stream,
            prefix="test")
        test_ctc_monitor = CTCMonitoring(x,
                                         input_mask,
                                         y_hat,
                                         eol_symbol,
                                         test_stream,
                                         prefix='test',
                                         every_n_epochs=1,
                                         phoneme_dict=phoneme_dict,
                                         black_list=black_list)
        extensions.append(test_monitor)
        extensions.append(test_ctc_monitor)

    #if not os.path.exists(experiment_path):
    #    os.makedirs(experiment_path)
    #best_path = os.path.join(experiment_path, 'best/')
    #if not os.path.exists(best_path):
    #    os.mkdir(best_path)
    #best_path = os.path.join(best_path, 'model.bin')
    # NOTE(review): the best-model path is '/dev/null', so the best model
    # is discarded; restore the commented-out path handling to keep it.
    extensions.append(EarlyStopping(to_watch, patience, '/dev/null'))
    extensions.extend([ProgressBar(), Printing()])

    # ------------------------------------------------------------------------
    # Main Loop

    main_loop = MainLoop(model=model,
                         data_stream=train_stream,
                         algorithm=algorithm,
                         extensions=extensions)

    print "Building time: %f" % (time.time() - t0)
    # if write_predictions:
    #     with open('predicted.txt', 'w') as f_pred:
    #         with open('targets.txt', 'w') as f_targets:
    #             evaluator = CTCEvaluator(
    #                 eol_symbol, x, input_mask, y_hat, phoneme_dict, black_list)
    #             evaluator.evaluate(dev_stream, file_pred=f_pred,
    #                                file_targets=f_targets)
    #     return
    main_loop.run()
示例#11
0
    def get_loss(self, model, target, output):
        """ Returns the loss function that can be used by the implementation-
            specific model.

            For the Keras backend, returns a tuple of
            ((data source name, placeholder) pairs, loss tensor).
            For the PyTorch backend, returns a list of
            [(data source name, placeholder) pairs, loss callable].
            Any other backend raises ValueError.
        """
        backend = model.get_backend()

        if backend.get_name() == 'keras':

            import keras.backend as K

            if 'warp' in self.variant:

                # Use the external warp-ctc bindings instead of Keras' CTC.
                logger.info(
                    'Attaching Warp-CTC loss function to model '
                    'output "%s".', target)

                if backend.get_toolchain() != 'theano':
                    logger.error('If you want to use warp-ctc, you need to '
                                 'use the Theano backend to Keras.')
                    raise ValueError('Warp-CTC is currently only supported '
                                     'with the Theano backend to Keras.')

            else:
                # Just use the built-in Keras CTC loss function.
                logger.debug(
                    'Attaching built-in Keras CTC loss function to '
                    'model output "%s".', target)

            # Names of the auxiliary data sources registered below.
            ctc_scaled = 'ctc_scaled_{}'.format(self.input_length)
            flattened_labels = 'ctc_flattened_labels_{}'.format(target)

            transcript_length = K.placeholder(ndim=2,
                                              dtype='int32',
                                              name=self.output_length)
            transcript = K.placeholder(
             ndim=2,
             dtype='int32',
             name=flattened_labels if 'warp' in self.variant \
              else self.output
            )
            utterance_length = K.placeholder(
             ndim=2,
             dtype='int32',
             name=self.input_length if self.relative_to is None \
              else ctc_scaled
            )

            # Scale utterance lengths relative to another output if asked.
            if self.relative_to is not None:
                model.add_data_source(
                    ctc_scaled,
                    ScaledSource(model,
                                 relative_to=self.relative_to,
                                 to_this=target,
                                 scale_this=self.input_length))

            # warp-ctc consumes the label batch flattened into one vector.
            if 'warp' in self.variant:
                model.add_data_source(
                    flattened_labels,
                    FlattenSource(self.output, self.output_length))

                try:
                    import ctc  # pylint: disable=import-error
                except ImportError:
                    logger.error(
                        'The warp-CTC loss function was requested,  '
                        'but we cannot find the "ctc" library. See our '
                        'troubleshooting page for helpful tips.')
                    raise ImportError(
                        'Cannot find the "ctc" library, which '
                        'is needed when using the "warp" variant of the CTC '
                        'loss function.')

                # warp-ctc: time-major activations, squeezed length
                # vectors, and labels shifted by +1 (0 presumably reserved
                # for the CTC blank -- TODO confirm).
                out = ctc.cpu_ctc_th(output.dimshuffle((1, 0, 2)),
                                     K.squeeze(utterance_length, -1),
                                     transcript[0] + 1,
                                     K.squeeze(transcript_length, -1))
            else:
                out = K.ctc_batch_cost(transcript, output, utterance_length,
                                       transcript_length)

            if 'loss_scale' in self.variant:
                logger.debug('Loss scaling is active.')
                out = out * K.mean(K.cast(utterance_length,
                                          K.dtype(out))) / 100

            return (
             (
              (self.output_length, transcript_length),
              (flattened_labels if 'warp' in self.variant \
               else self.output, transcript),
              (self.input_length if self.relative_to is None \
               else ctc_scaled, utterance_length)
             ),
             out
            )

        elif backend.get_name() == 'pytorch':

            if 'warp' not in self.variant:
                logger.error(
                    'PyTorch does not include a native CTC loss '
                    'function yet. However, PyTorch bindings to Warp CTC are '
                    'available (SeanNaren/warp-ctc). Try installing that, and '
                    'then settings variant=warp.')
                raise ValueError('Only Warp CTC is supported for PyTorch '
                                 'right now.')

            ctc_scaled = 'ctc_scaled_{}'.format(self.input_length)
            flattened_labels = 'ctc_flattened_labels_{}'.format(target)
            transcript_length = model.data.placeholder(self.output_length,
                                                       location='cpu',
                                                       data_type='int')
            transcript = model.data.placeholder(flattened_labels,
                                                location='cpu',
                                                data_type='int')
            utterance_length = model.data.placeholder(
                self.input_length if self.relative_to is None else ctc_scaled,
                location='cpu',
                data_type='int')

            # Scale utterance lengths relative to another output if asked.
            if self.relative_to is not None:
                model.add_data_source(
                    ctc_scaled,
                    ScaledSource(model,
                                 relative_to=self.relative_to,
                                 to_this=target,
                                 scale_this=self.input_length))

            # warp-ctc consumes the label batch flattened into one vector.
            if 'warp' in self.variant:
                model.add_data_source(
                    flattened_labels,
                    FlattenSource(self.output, self.output_length))

            try:
                from warpctc_pytorch import CTCLoss  # pylint: disable=import-error
            except ImportError:
                # NOTE(review): "out troubleshooting page" in the message
                # below is presumably a typo for "our".
                logger.error(
                    'The warp-CTC loss function was requested,  '
                    'but we cannot find the "warpctc_pytorch" library. See '
                    'out troubleshooting page for helpful tips.')
                raise ImportError(
                    'Cannot find the "warpctc_pytorch" library, '
                    'which is needed when using the "warp" variant of the CTC '
                    'loss function.')

            loss = model.data.move(CTCLoss())

            def basic_ctc_loss(inputs, output):
                """ Computes CTC loss, averaged over the batch dimension.
                """
                return loss(
                    output.transpose(1, 0).contiguous(),
                    inputs[0][0] + 1,  # transcript[0]+1
                    inputs[1].squeeze(1),  # K.squeeze(utterance_length, -1),
                    inputs[2].squeeze(1)  # K.squeeze(transcript_length, -1)
                ) / output.size(0)

            if 'loss_scale' in self.variant:
                logger.debug('Loss scaling is active.')

                def loss_scale(inputs, output):
                    """ Computes CTC loss scaled by mean utterance length.
                    """
                    factor = inputs[1].float().mean().data[0] / 100.
                    return basic_ctc_loss(inputs, output) * factor

                get_ctc_loss = loss_scale
            else:
                get_ctc_loss = basic_ctc_loss

            return [
             [
              (flattened_labels if 'warp' in self.variant \
               else self.output, transcript),
              (self.input_length if self.relative_to is None \
               else ctc_scaled, utterance_length),
              (self.output_length, transcript_length)
             ],
             get_ctc_loss
            ]

        else:
            raise ValueError(
                'Unsupported backend "{}" for loss function "{}"'.format(
                    backend.get_name(), self.get_name()))
示例#12
0
文件: ctc.py 项目: Navdevl/kur
	def get_loss(self, model, target, output):
		""" Returns the loss function that can be used by the implementation-
			specific model.

			# Arguments

			model: the model wrapper; supplies the backend and (for PyTorch)
				the data placeholder/move helpers.
			target: str. Name of the model output this loss attaches to.
			output: the symbolic/model output tensor for `target`.

			# Return value

			Backend-dependent:
			- Keras: a tuple `(bindings, loss_tensor)` where `bindings` is a
			  tuple of `(data-source name, placeholder)` pairs.
			- PyTorch: a list `[bindings, loss_callable]` where `bindings` is
			  a list of `(data-source name, placeholder)` pairs and
			  `loss_callable(inputs, output)` computes the scalar loss.

			# Raises

			ValueError: unsupported backend, or an unsupported
				backend/variant combination.
			ImportError: the "warp" variant was requested but the required
				CTC bindings are not installed.
		"""
		backend = model.get_backend()

		if backend.get_name() == 'keras':

			import keras.backend as K

			if 'warp' in self.variant:

				# Use the external warp-CTC bindings instead of the
				# built-in Keras CTC loss function.
				logger.info('Attaching Warp-CTC loss function to model '
					'output "%s".', target)

				# The "ctc" warp bindings used below are Theano-only.
				if backend.get_toolchain() != 'theano':
					logger.error('If you want to use warp-ctc, you need to '
						'use the Theano backend to Keras.')
					raise ValueError('Warp-CTC is currently only supported '
						'with the Theano backend to Keras.')

			else:
				# Just use the built-in Keras CTC loss function.
				logger.debug('Attaching built-in Keras CTC loss function to '
					'model output "%s".', target)

			# Names of the auxiliary data sources this loss may register.
			ctc_scaled = 'ctc_scaled_{}'.format(self.input_length)
			flattened_labels = 'ctc_flattened_labels_{}'.format(target)

			# Placeholders for the label sequences and the per-sample
			# lengths that are fed in alongside the acoustic data.
			transcript_length = K.placeholder(
				ndim=2,
				dtype='int32',
				name=self.output_length
			)
			transcript = K.placeholder(
				ndim=2,
				dtype='int32',
				name=flattened_labels if 'warp' in self.variant \
					else self.output
			)
			utterance_length = K.placeholder(
				ndim=2,
				dtype='int32',
				name=self.input_length if self.relative_to is None \
					else ctc_scaled
			)

			# If input lengths are expressed relative to another tensor
			# (e.g. after downsampling), register a source that rescales
			# them to match the model output's time dimension.
			if self.relative_to is not None:
				model.add_data_source(
					ctc_scaled,
					ScaledSource(
						model,
						relative_to=self.relative_to,
						to_this=target,
						scale_this=self.input_length
					)
				)

			# Warp-CTC consumes a single flattened (concatenated) label
			# vector rather than a padded batch of label rows.
			if 'warp' in self.variant:
				model.add_data_source(
					flattened_labels,
					FlattenSource(
						self.output,
						self.output_length
					)
				)

				try:
					import ctc					# pylint: disable=import-error
				except ImportError:
					logger.error('The warp-CTC loss function was requested,  '
						'but we cannot find the "ctc" library. See our '
						'troubleshooting page for helpful tips.')
					raise ImportError('Cannot find the "ctc" library, which '
						'is needed when using the "warp" variant of the CTC '
						'loss function.')

				# warp-CTC expects time-major activations; labels are
				# shifted by +1 since index 0 is the CTC blank.
				out = ctc.cpu_ctc_th(
					output.dimshuffle((1, 0, 2)),
					K.squeeze(utterance_length, -1),
					transcript[0]+1,
					K.squeeze(transcript_length, -1)
				)
			else:
				out = K.ctc_batch_cost(
					transcript,
					output,
					utterance_length,
					transcript_length
				)

			if 'loss_scale' in self.variant:
				logger.debug('Loss scaling is active.')
				# Scale the loss by the mean utterance length / 100.
				out = out * K.mean(
					K.cast(utterance_length, K.dtype(out))
				) / 100

			return (
				(
					(self.output_length, transcript_length),
					(flattened_labels if 'warp' in self.variant \
						else self.output, transcript),
					(self.input_length if self.relative_to is None \
						else ctc_scaled, utterance_length)
				),
				out
			)

		elif backend.get_name() == 'pytorch':

			# Only the warp-CTC bindings are supported for PyTorch.
			if 'warp' not in self.variant:
				logger.error('PyTorch does not include a native CTC loss '
					'function yet. However, PyTorch bindings to Warp CTC are '
					'available (SeanNaren/warp-ctc). Try installing that, and '
					'then setting variant=warp.')
				raise ValueError('Only Warp CTC is supported for PyTorch '
					'right now.')

			# Names of the auxiliary data sources this loss may register.
			ctc_scaled = 'ctc_scaled_{}'.format(self.input_length)
			flattened_labels = 'ctc_flattened_labels_{}'.format(target)
			# CPU placeholders: warp-ctc's CTCLoss takes its length/label
			# arguments on the CPU as integer tensors.
			transcript_length = model.data.placeholder(
				self.output_length,
				location='cpu',
				data_type='int'
			)
			transcript = model.data.placeholder(
				flattened_labels,
				location='cpu',
				data_type='int'
			)
			utterance_length = model.data.placeholder(
				self.input_length if self.relative_to is None else ctc_scaled,
				location='cpu',
				data_type='int'
			)

			# If input lengths are expressed relative to another tensor,
			# register a source that rescales them (mirrors the Keras path).
			if self.relative_to is not None:
				model.add_data_source(
					ctc_scaled,
					ScaledSource(
						model,
						relative_to=self.relative_to,
						to_this=target,
						scale_this=self.input_length
					)
				)

			# Warp-CTC consumes a flattened (concatenated) label vector.
			if 'warp' in self.variant:
				model.add_data_source(
					flattened_labels,
					FlattenSource(
						self.output,
						self.output_length
					)
				)

			try:
				from warpctc_pytorch import CTCLoss	# pylint: disable=import-error
			except ImportError:
				logger.error('The warp-CTC loss function was requested,  '
					'but we cannot find the "warpctc_pytorch" library. See '
					'our troubleshooting page for helpful tips.')
				raise ImportError('Cannot find the "warpctc_pytorch" library, '
					'which is needed when using the "warp" variant of the CTC '
					'loss function.')

			# Move the loss module to the same device as the model data.
			loss = model.data.move(CTCLoss())

			def basic_ctc_loss(inputs, output):
				""" Computes CTC loss.
				"""
				# `inputs` is ordered per the bindings returned below:
				# (transcript, utterance_length, transcript_length).
				# Activations are made time-major; labels shifted by +1
				# since index 0 is the CTC blank. Normalized by batch size.
				return loss(
					output.transpose(1, 0).contiguous(),
					inputs[0][0]+1,		# transcript[0]+1
					inputs[1].squeeze(1),	# K.squeeze(utterance_length, -1),
					inputs[2].squeeze(1)	# K.squeeze(transcript_length, -1)
				) / output.size(0)

			if 'loss_scale' in self.variant:
				logger.debug('Loss scaling is active.')

				def loss_scale(inputs, output):
					""" Computes CTC loss.
					"""
					# Scale by the mean utterance length / 100, matching
					# the Keras 'loss_scale' behavior.
					factor = inputs[1].float().mean().data[0] / 100.
					return basic_ctc_loss(inputs, output) * factor

				get_ctc_loss = loss_scale
			else:
				get_ctc_loss = basic_ctc_loss

			# NOTE(review): binding order here differs from the Keras branch
			# (which puts output_length first) — presumably each backend
			# consumes its own ordering; confirm before unifying.
			return [
				[
					(flattened_labels if 'warp' in self.variant \
						else self.output, transcript),
					(self.input_length if self.relative_to is None \
						else ctc_scaled, utterance_length),
					(self.output_length, transcript_length)
				],
				get_ctc_loss
			]

		else:
			raise ValueError('Unsupported backend "{}" for loss function "{}"'
				.format(backend.get_name(), self.get_name()))