def compile_train_fn(model, learning_rate=2e-4):
    """ Build the CTC training routine for speech models.
    Args:
        model: A keras model (built=True) instance
        learning_rate (float): step size for the Nesterov-momentum update
    Returns:
        train_fn (theano.function): Function that takes in acoustic inputs,
            and updates the model. Returns network outputs and ctc cost
    """
    logger.info("Building train_fn")
    acoustic_input = model.inputs[0]
    network_output = model.outputs[0]
    # Per-utterance lengths and flattened labels, as warp-ctc expects.
    output_lens = K.placeholder(ndim=1, dtype='int32')
    label = K.placeholder(ndim=1, dtype='int32')
    label_lens = K.placeholder(ndim=1, dtype='int32')
    # warp-ctc wants time-major activations: (time, batch, classes).
    network_output = network_output.dimshuffle((1, 0, 2))

    ctc_cost = ctc.cpu_ctc_th(network_output, output_lens,
                              label, label_lens).mean()
    trainable_vars = model.trainable_weights
    grads = K.gradients(ctc_cost, trainable_vars)
    # Rescale gradients so their global norm never exceeds 100.
    grads = lasagne.updates.total_norm_constraint(grads, 100)
    updates = lasagne.updates.nesterov_momentum(grads, trainable_vars,
                                                learning_rate, 0.99)
    train_fn = K.function([acoustic_input, output_lens, label, label_lens,
                           K.learning_phase()],
                          [network_output, ctc_cost],
                          updates=updates)
    return train_fn
def __init__(self, model):
    """Wrap *model* with a warp-CTC loss graph.

    Builds symbolic int32 vectors for the flattened target labels
    (``yhat``) and their per-utterance lengths (``ylen``), attaches the
    summed CTC cost over the batch, and compiles a callable.
    """
    # Flattened label sequence and per-sequence label lengths, int32 as
    # required by the warp-ctc bindings.
    self._yhat = T.vector(name='yhat', dtype='int32')
    self._ylen = T.vector(name='ylen', dtype='int32')
    self._model = model
    # Total (summed, not averaged) CTC cost over the batch.
    self._loss = ctc.cpu_ctc_th(model.output, model.output_size,
                                self._yhat, self._ylen).sum()
    # NOTE(review): self.input / self.output are presumably properties
    # defined on the enclosing class (not visible here) — confirm they
    # exist before __init__ runs.
    self._callf = theano.function(inputs=self.input, outputs=self.output)
def create_theano_func():
    """Compile Theano functions for the mean CTC cost and its gradient.

    Returns:
        tuple: ``(cost_fn, grad_fn)`` where both functions take
        ``(acts, act_lens, labels, label_lens)`` and return the mean CTC
        cost and its gradient w.r.t. the activations, respectively.
    """
    activations = T.ftensor3()
    activation_lens = T.ivector()
    flat_labels = T.ivector()
    flat_label_lens = T.ivector()
    fn_inputs = [activations, activation_lens, flat_labels, flat_label_lens]

    per_seq_costs = cpu_ctc_th(activations, activation_lens,
                               flat_labels, flat_label_lens)
    mean_cost = T.mean(per_seq_costs)
    cost_grad = T.grad(mean_cost, activations)

    cost_fn = theano.function(fn_inputs, mean_cost,
                              allow_input_downcast=True)
    grad_fn = theano.function(fn_inputs, cost_grad,
                              allow_input_downcast=True)
    return cost_fn, grad_fn
def compile_test_fn(model):
    """ Build a testing routine for speech models.
    Args:
        model: A keras model (built=True) instance
    Returns:
        val_fn (theano.function): Function that takes in acoustic inputs,
            and calculates the loss. Returns network outputs and ctc cost
    """
    logger.info("Building val_fn")
    # Length/label placeholders required by the warp-ctc cost.
    output_lens = K.placeholder(ndim=1, dtype='int32')
    label = K.placeholder(ndim=1, dtype='int32')
    label_lens = K.placeholder(ndim=1, dtype='int32')

    # Reorder activations to time-major (time, batch, classes) for warp-ctc.
    softmax_seq = model.outputs[0].dimshuffle((1, 0, 2))
    mean_ctc_cost = ctc.cpu_ctc_th(softmax_seq, output_lens,
                                   label, label_lens).mean()

    return K.function(
        [model.inputs[0], output_lens, label, label_lens,
         K.learning_phase()],
        [softmax_seq, mean_ctc_cost])
h1, n_hidden, grad_clipping=grad_clip,
    nonlinearity=lasagne.nonlinearities.rectify, backwards=True)
# NOTE(review): the lines above finish a RecurrentLayer(...) call whose
# opening is outside this view; h2f/h2b appear to be the forward/backward
# halves of a bidirectional recurrent layer — confirm against the full file.
h2 = lasagne.layers.ElemwiseSumLayer([h2f, h2b])
# Final recurrent layer emits unnormalized class scores (linear output),
# as warp-ctc applies its own softmax internally.
h3 = lasagne.layers.RecurrentLayer(h2, num_classes,
                                   grad_clipping=grad_clip,
                                   nonlinearity=lasagne.nonlinearities.linear)
# Reshape to time-major (max_len, mbsz, num_classes) for warp-ctc.
l_out = lasagne.layers.ReshapeLayer(h3, ((max_len, mbsz, num_classes)))
network_output = lasagne.layers.get_output(l_out)
cost = T.mean(ctc.cpu_ctc_th(network_output, input_lens, output,
                             output_lens))
# Gradient w.r.t. the network activations (for inspection), not the params.
grads = T.grad(cost, wrt=network_output)
all_params = lasagne.layers.get_all_params(l_out)
updates = lasagne.updates.adam(cost, all_params, 0.001)
# Compiled entry points: one training step, forward pass, activation grads.
train = theano.function([l_in.input_var, input_lens, output, output_lens],
                        cost, updates=updates)
predict = theano.function([l_in.input_var], network_output)
get_grad = theano.function([l_in.input_var, input_lens, output, output_lens],
                           grads)
from loader import DataLoader
# Call continues past this view.
data_loader = DataLoader(mbsz=mbsz, min_len=min_len,
def train(step_rule, label_dim, state_dim, epochs, seed, dropout, test_cost,
          experiment_path, features, weight_noise, to_watch, patience,
          batch_size, batch_norm, **kwargs):
    """Train a CTC phoneme recognizer on TIMIT with Blocks.

    Builds the data streams, the recurrent graph (batch-normalized LSTM or
    plain LSTM depending on ``batch_norm``), the warp-CTC cost, the
    monitoring extensions, and runs the Blocks main loop to completion.

    NOTE(review): ``dropout``, ``weight_noise``, ``experiment_path``,
    ``seq_len``, ``recurrent_init`` and ``update_stream`` are created but
    never used in the visible code — confirm whether they are vestigial.
    """
    print '.. TIMIT experiment'
    print '.. arguments:', ' '.join(sys.argv)
    t0 = time.time()

    # ------------------------------------------------------------------------
    # Streams
    rng = np.random.RandomState(seed)
    stream_args = dict(rng=rng, batch_size=batch_size)
    print '.. initializing iterators'
    train_dataset = Timit('train', features=features)
    train_stream = construct_stream(train_dataset, **stream_args)
    dev_dataset = Timit('dev', features=features)
    dev_stream = construct_stream(dev_dataset, **stream_args)
    test_dataset = Timit('test', features=features)
    test_stream = construct_stream(test_dataset, **stream_args)
    update_stream = construct_stream(train_dataset, n_batches=100,
                                     **stream_args)

    # Map raw phones to the reduced phoneme set where a mapping exists.
    phone_dict = train_dataset.get_phoneme_dict()
    phoneme_dict = {k: phone_to_phoneme_dict[v]
                    if v in phone_to_phoneme_dict else v
                    for k, v in phone_dict.iteritems()}
    ind_to_phoneme = {v: k for k, v in phoneme_dict.iteritems()}
    eol_symbol = ind_to_phoneme['<STOP>']

    # ------------------------------------------------------------------------
    # Graph
    print '.. building model'
    # Time-major symbolic inputs: x is (time, batch, features), y is
    # (labels, batch); masks mark valid timesteps/labels.
    x = T.tensor3('features')
    y = T.matrix('phonemes')
    input_mask = T.matrix('features_mask')
    output_mask = T.matrix('phonemes_mask')

    theano.config.compute_test_value = 'off'
    x.tag.test_value = np.random.randn(100, 24, 123).astype(floatX)
    y.tag.test_value = np.ones((30, 24), dtype=floatX)
    input_mask.tag.test_value = np.ones((100, 24), dtype=floatX)
    output_mask.tag.test_value = np.ones((30, 24), dtype=floatX)

    seq_len = 100
    input_dim = 123
    activation = Tanh()
    recurrent_init = IdentityInit(0.99)

    # Choose the recurrent transition; the commented blocks are earlier
    # SimpleRecurrent experiments kept for reference.
    if batch_norm:
        rec1 = LSTMBatchNorm(name='rec1',
                             dim=state_dim,
                             activation=activation,
                             weights_init=NormalizedInitialization())
        #rec1 = SimpleRecurrentBatchNorm(name='rec1',
        #                                dim=state_dim,
        #                                activation=activation,
        #                                seq_len=seq_len,
        #                                weights_init=recurrent_init)
        #rec2 = SimpleRecurrentBatchNorm(name='rec2',
        #                                dim=state_dim,
        #                                activation=activation,
        #                                seq_len=seq_len,
        #                                weights_init=recurrent_init)
        #rec3 = SimpleRecurrentBatchNorm(name='rec3',
        #                                dim=state_dim,
        #                                activation=activation,
        #                                seq_len=seq_len,
        #                                weights_init=recurrent_init)
    else:
        rec1 = LSTM(name='rec1', dim=state_dim, activation=activation,
                    weights_init=NormalizedInitialization())
        #rec1 = SimpleRecurrent(name='rec1', dim=state_dim, activation=activation,
        #                       weights_init=recurrent_init)
        #rec2 = SimpleRecurrent(name='rec2', dim=state_dim, activation=activation,
        #                       weights_init=recurrent_init)
        #rec3 = SimpleRecurrent(name='rec3', dim=state_dim, activation=activation,
        #                       weights_init=recurrent_init)
    rec1.initialize()
    #rec2.initialize()
    #rec3.initialize()

    # Single recurrent stage projecting input_dim -> state_dim ->
    # label_dim + 1 (the extra class is presumably the CTC blank).
    s1 = MyRecurrent(rec1, [input_dim, state_dim, label_dim + 1],
                     activations=[Identity(), Identity()], name='s1')
    #s2 = MyRecurrent(rec2, [state_dim, state_dim, state_dim],
    #                 activations=[Identity(), Identity()], name='s2')
    #s3 = MyRecurrent(rec3, [state_dim, state_dim, label_dim + 1],
    #                 activations=[Identity(), Identity()], name='s3')
    s1.initialize()
    #s2.initialize()
    #s3.initialize()

    o1 = s1.apply(x, input_mask)
    #o2 = s2.apply(o1)
    #y_hat_o = s3.apply(o2)
    y_hat_o = o1
    shape = y_hat_o.shape
    # Softmax over the class axis only (flatten, softmax, restore shape).
    y_hat = Softmax().apply(y_hat_o.reshape((-1, shape[-1]))).reshape(shape)
    y_mask = output_mask
    y_hat_mask = input_mask

    # ------------------------------------------------------------------------
    # Costs and Algorithm
    # warp-ctc takes raw (pre-softmax) activations plus per-sequence
    # lengths; labels are shifted by +1, presumably to reserve index 0 for
    # the CTC blank — confirm against the ctc bindings.
    ctc_cost = T.sum(ctc.cpu_ctc_th(
        y_hat_o, T.sum(y_hat_mask, axis=0),
        y + T.ones_like(y), T.sum(y_mask, axis=0)))
    batch_cost = ctc_cost.copy(name='batch_cost')
    bs = y.shape[1]
    cost_train = aggregation.mean(batch_cost, bs).copy("sequence_cost")
    cost_per_character = aggregation.mean(batch_cost,
                                          output_mask.sum()).copy(
        "character_cost")
    cg_train = ComputationGraph(cost_train)
    model = Model(cost_train)
    train_cost_per_character = aggregation.mean(cost_train,
                                                output_mask.sum()).copy(
        "train_character_cost")

    algorithm = GradientDescent(step_rule=step_rule, cost=cost_train,
                                parameters=cg_train.parameters,
                                on_unused_sources='warn')

    # ------------------------------------------------------------------------
    # Monitoring and extensions
    # Track cost plus parameter/gradient norms for every parameter.
    parameters = model.get_parameter_dict()
    observed_vars = [cost_train, train_cost_per_character,
                     aggregation.mean(algorithm.total_gradient_norm)]
    for name, param in parameters.iteritems():
        observed_vars.append(param.norm(2).copy(name + "_norm"))
        observed_vars.append(algorithm.gradients[param].norm(2).copy(name +
            "_grad_norm"))
    train_monitor = TrainingDataMonitoring(
        variables=observed_vars,
        prefix="train", after_epoch=True)
    dev_monitor = DataStreamMonitoring(
        variables=[cost_train, cost_per_character],
        data_stream=dev_stream,
        prefix="dev"
    )
    train_ctc_monitor = CTCMonitoring(x, input_mask, y_hat, eol_symbol,
                                      train_stream, prefix='train',
                                      every_n_epochs=1,
                                      before_training=True,
                                      phoneme_dict=phoneme_dict,
                                      black_list=black_list, train=True)
    dev_ctc_monitor = CTCMonitoring(x, input_mask, y_hat, eol_symbol,
                                    dev_stream, prefix='dev',
                                    every_n_epochs=1,
                                    phoneme_dict=phoneme_dict,
                                    black_list=black_list)

    extensions = []
    # Optionally warm-start from a checkpoint.
    if 'load_path' in kwargs:
        extensions.append(Load(kwargs['load_path']))
    extensions.extend([FinishAfter(after_n_epochs=epochs),
                       train_monitor, dev_monitor,
                       train_ctc_monitor, dev_ctc_monitor])
    if test_cost:
        test_monitor = DataStreamMonitoring(
            variables=[cost_train, cost_per_character],
            data_stream=test_stream,
            prefix="test"
        )
        test_ctc_monitor = CTCMonitoring(x, input_mask, y_hat, eol_symbol,
                                         test_stream, prefix='test',
                                         every_n_epochs=1,
                                         phoneme_dict=phoneme_dict,
                                         black_list=black_list)
        extensions.append(test_monitor)
        extensions.append(test_ctc_monitor)

    #if not os.path.exists(experiment_path):
    #    os.makedirs(experiment_path)
    #best_path = os.path.join(experiment_path, 'best/')
    #if not os.path.exists(best_path):
    #    os.mkdir(best_path)
    #best_path = os.path.join(best_path, 'model.bin')
    # NOTE(review): checkpointing is disabled — EarlyStopping writes to
    # /dev/null, so the best model is not saved.
    extensions.append(EarlyStopping(to_watch, patience, '/dev/null'))
    extensions.extend([ProgressBar(), Printing()])

    # ------------------------------------------------------------------------
    # Main Loop
    main_loop = MainLoop(model=model, data_stream=train_stream,
                         algorithm=algorithm, extensions=extensions)

    print "Building time: %f" % (time.time() - t0)
    # if write_predictions:
    #     with open('predicted.txt', 'w') as f_pred:
    #         with open('targets.txt', 'w') as f_targets:
    #             evaluator = CTCEvaluator(
    #                 eol_symbol, x, input_mask, y_hat, phoneme_dict,
    #                 black_list)
    #             evaluator.evaluate(dev_stream, file_pred=f_pred,
    #                                file_targets=f_targets)
    #     return
    main_loop.run()
l_in, n_hidden, grad_clipping=grad_clip,
    nonlinearity=lasagne.nonlinearities.rectify, backwards=True
)
# NOTE(review): the lines above finish a RecurrentLayer(...) call opened
# outside this view; h1f/h1b are the forward/backward halves of the first
# bidirectional layer — confirm against the full file.
h1 = lasagne.layers.ElemwiseSumLayer([h1f, h1b])
# Second bidirectional recurrent layer: forward and backward passes summed.
h2f = lasagne.layers.RecurrentLayer(h1, n_hidden, grad_clipping=grad_clip,
                                    nonlinearity=lasagne.nonlinearities.rectify)
h2b = lasagne.layers.RecurrentLayer(
    h1, n_hidden, grad_clipping=grad_clip,
    nonlinearity=lasagne.nonlinearities.rectify, backwards=True
)
h2 = lasagne.layers.ElemwiseSumLayer([h2f, h2b])
# Output layer emits unnormalized class scores (linear), as warp-ctc
# applies its own softmax internally.
h3 = lasagne.layers.RecurrentLayer(h2, num_classes,
                                   grad_clipping=grad_clip,
                                   nonlinearity=lasagne.nonlinearities.linear)
# Reshape to time-major (max_len, mbsz, num_classes) for warp-ctc.
l_out = lasagne.layers.ReshapeLayer(h3, ((max_len, mbsz, num_classes)))
network_output = lasagne.layers.get_output(l_out)
cost = T.mean(ctc.cpu_ctc_th(network_output, input_lens, output,
                             output_lens))
# Gradient w.r.t. the activations (for inspection), not the parameters.
grads = T.grad(cost, wrt=network_output)
all_params = lasagne.layers.get_all_params(l_out)
updates = lasagne.updates.adam(cost, all_params, 0.001)
# Compiled entry points: one training step, forward pass, activation grads.
train = theano.function([l_in.input_var, input_lens, output, output_lens],
                        cost, updates=updates)
predict = theano.function([l_in.input_var], network_output)
get_grad = theano.function([l_in.input_var, input_lens, output, output_lens],
                           grads)
from loader import DataLoader
data_loader = DataLoader(mbsz=mbsz, min_len=min_len,
                         max_len=max_len, num_classes=num_classes)

# Training loop; its body continues past this view.
i = 1
while True:
    i += 1
def get_loss(self, model, target, output):
    """ Returns the loss function that can be used by the implementation-
        specific model.

    Returns a tuple ``(data_bindings, loss_tensor)`` where
    ``data_bindings`` maps data-source names to their placeholders.
    Only the Keras backend is supported; ``variant`` selects between the
    built-in Keras CTC (None) and warp-CTC ('warp').
    """
    backend = model.get_backend()

    if backend.get_name() == 'keras':

        import keras.backend as K

        if self.variant is None:
            # Just use the built-in Keras CTC loss function.
            logger.debug('Attaching built-in Keras CTC loss function to '
                'model output "%s".', target)
        elif self.variant == 'warp':
            # Use the external warp-CTC bindings instead; they are
            # Theano-only.
            logger.info('Attaching Warp-CTC loss function to model '
                'output "%s".', target)
            if backend.get_toolchain() != 'theano':
                logger.error('If you want to use warp-ctc, you need to '
                    'use the Theano backend to Keras.')
                raise ValueError('Warp-CTC is currently only supported '
                    'with the Theano backend to Keras.')
        else:
            raise ValueError('Unsupported variant "{}" on loss function '
                '"{}" for backend "{}".'.format(self.variant,
                    self.get_name(), backend.get_name()))

        # Derived data-source names for the scaled utterance lengths and
        # the flattened (warp-ctc style) label sequences.
        ctc_scaled = 'ctc_scaled_{}'.format(self.input_length)
        flattened_labels = 'ctc_flattened_labels_{}'.format(target)
        transcript_length = K.placeholder(
            ndim=2,
            dtype='int32',
            name=self.output_length
        )
        transcript = K.placeholder(
            ndim=2,
            dtype='int32',
            name=self.output if self.variant is None \
                else flattened_labels
        )
        utterance_length = K.placeholder(
            ndim=2,
            dtype='int32',
            name=self.input_length if self.relative_to is None \
                else ctc_scaled
        )

        if self.relative_to is not None:
            model.add_data_source(
                ctc_scaled,
                ScaledSource(
                    model,
                    relative_to=self.relative_to,
                    to_this=target,
                    scale_this=self.input_length
                )
            )

        if self.variant == 'warp':
            model.add_data_source(
                flattened_labels,
                FlattenSource(
                    self.output,
                    self.output_length
                )
            )

        if self.variant is None:
            out = K.ctc_batch_cost(
                transcript,
                output,
                utterance_length,
                transcript_length
            )
        else:
            import ctc  # pylint: disable=import-error
            # warp-ctc wants time-major activations and flat int vectors;
            # labels shifted by +1, presumably to reserve 0 for the blank.
            out = ctc.cpu_ctc_th(
                output.dimshuffle((1, 0, 2)),
                K.squeeze(utterance_length, -1),
                transcript[0]+1,
                K.squeeze(transcript_length, -1)
            )

        return (
            (
                (self.output_length, transcript_length),
                (self.output if self.variant is None \
                    else flattened_labels, transcript),
                (self.input_length if self.relative_to is None \
                    else ctc_scaled, utterance_length)
            ),
            out
        )

    else:
        raise ValueError('Unsupported backend "{}" for loss function "{}"'
            .format(backend.get_name(), self.get_name()))
def train(step_rule, label_dim, state_dim, epochs, seed, dropout, test_cost,
          experiment_path, features, weight_noise, to_watch, patience,
          batch_size, batch_norm, **kwargs):
    """Train a CTC phoneme recognizer on TIMIT with Blocks.

    Builds the data streams, a TimLSTM + linear output graph, the warp-CTC
    cost, monitoring/early-stopping extensions, and runs the Blocks main
    loop to completion.

    NOTE(review): ``dropout``, ``weight_noise``, ``experiment_path``,
    ``seq_len``, ``recurrent_init`` and ``update_stream`` are created but
    never used in the visible code — confirm whether they are vestigial.
    """
    print '.. TIMIT experiment'
    print '.. arguments:', ' '.join(sys.argv)
    t0 = time.time()

    # ------------------------------------------------------------------------
    # Streams
    rng = np.random.RandomState(seed)
    stream_args = dict(rng=rng, batch_size=batch_size)
    print '.. initializing iterators'
    train_dataset = Timit('train', features=features)
    train_stream = construct_stream(train_dataset, **stream_args)
    dev_dataset = Timit('dev', features=features)
    dev_stream = construct_stream(dev_dataset, **stream_args)
    test_dataset = Timit('test', features=features)
    test_stream = construct_stream(test_dataset, **stream_args)
    update_stream = construct_stream(train_dataset, n_batches=100,
                                     **stream_args)

    # Map raw phones to the reduced phoneme set where a mapping exists.
    phone_dict = train_dataset.get_phoneme_dict()
    phoneme_dict = {
        k: phone_to_phoneme_dict[v] if v in phone_to_phoneme_dict else v
        for k, v in phone_dict.iteritems()
    }
    ind_to_phoneme = {v: k for k, v in phoneme_dict.iteritems()}
    eol_symbol = ind_to_phoneme['<STOP>']

    # ------------------------------------------------------------------------
    # Graph
    print '.. building model'
    # Time-major symbolic inputs: x is (time, batch, features), y is
    # (labels, batch); masks mark valid timesteps/labels.
    x = T.tensor3('features')
    y = T.matrix('phonemes')
    input_mask = T.matrix('features_mask')
    output_mask = T.matrix('phonemes_mask')

    theano.config.compute_test_value = 'off'
    x.tag.test_value = np.random.randn(100, 24, 123).astype(floatX)
    y.tag.test_value = np.ones((30, 24), dtype=floatX)
    input_mask.tag.test_value = np.ones((100, 24), dtype=floatX)
    output_mask.tag.test_value = np.ones((30, 24), dtype=floatX)

    seq_len = 100
    input_dim = 123
    activation = Tanh()
    recurrent_init = IdentityInit(0.99)

    # LSTM transition followed by a linear projection to label_dim + 1
    # classes (the extra class is presumably the CTC blank).
    rec1 = TimLSTM(not batch_norm, input_dim, state_dim, activation,
                   name='LSTM')
    rec1.initialize()
    l1 = Linear(state_dim, label_dim + 1, name='out_linear',
                weights_init=Orthogonal(), biases_init=Constant(0.0))
    l1.initialize()
    o1 = rec1.apply(x)
    y_hat_o = l1.apply(o1)
    shape = y_hat_o.shape
    # Softmax over the class axis only (flatten, softmax, restore shape).
    y_hat = Softmax().apply(y_hat_o.reshape((-1, shape[-1]))).reshape(shape)
    y_mask = output_mask
    y_hat_mask = input_mask

    # ------------------------------------------------------------------------
    # Costs and Algorithm
    # warp-ctc takes raw (pre-softmax) activations plus per-sequence
    # lengths; labels are shifted by +1, presumably to reserve index 0 for
    # the CTC blank — confirm against the ctc bindings.
    ctc_cost = T.sum(
        ctc.cpu_ctc_th(y_hat_o, T.sum(y_hat_mask, axis=0),
                       y + T.ones_like(y), T.sum(y_mask, axis=0)))
    batch_cost = ctc_cost.copy(name='batch_cost')
    bs = y.shape[1]
    cost_train = aggregation.mean(batch_cost, bs).copy("sequence_cost")
    cost_per_character = aggregation.mean(
        batch_cost, output_mask.sum()).copy("character_cost")
    cg_train = ComputationGraph(cost_train)
    model = Model(cost_train)
    train_cost_per_character = aggregation.mean(
        cost_train, output_mask.sum()).copy("train_character_cost")

    algorithm = GradientDescent(step_rule=step_rule, cost=cost_train,
                                parameters=cg_train.parameters,
                                on_unused_sources='warn')

    # ------------------------------------------------------------------------
    # Monitoring and extensions
    # Track cost plus parameter/gradient norms for every parameter.
    parameters = model.get_parameter_dict()
    observed_vars = [
        cost_train, train_cost_per_character,
        aggregation.mean(algorithm.total_gradient_norm)
    ]
    for name, param in parameters.iteritems():
        observed_vars.append(param.norm(2).copy(name + "_norm"))
        observed_vars.append(
            algorithm.gradients[param].norm(2).copy(name + "_grad_norm"))
    train_monitor = TrainingDataMonitoring(variables=observed_vars,
                                           prefix="train",
                                           after_epoch=True)
    dev_monitor = DataStreamMonitoring(
        variables=[cost_train, cost_per_character],
        data_stream=dev_stream,
        prefix="dev")
    train_ctc_monitor = CTCMonitoring(x, input_mask, y_hat, eol_symbol,
                                      train_stream, prefix='train',
                                      every_n_epochs=1,
                                      before_training=True,
                                      phoneme_dict=phoneme_dict,
                                      black_list=black_list, train=True)
    dev_ctc_monitor = CTCMonitoring(x, input_mask, y_hat, eol_symbol,
                                    dev_stream, prefix='dev',
                                    every_n_epochs=1,
                                    phoneme_dict=phoneme_dict,
                                    black_list=black_list)

    extensions = []
    # Optionally warm-start from a checkpoint.
    if 'load_path' in kwargs:
        extensions.append(Load(kwargs['load_path']))
    extensions.extend([
        FinishAfter(after_n_epochs=epochs), train_monitor, dev_monitor,
        train_ctc_monitor, dev_ctc_monitor
    ])
    if test_cost:
        test_monitor = DataStreamMonitoring(
            variables=[cost_train, cost_per_character],
            data_stream=test_stream,
            prefix="test")
        test_ctc_monitor = CTCMonitoring(x, input_mask, y_hat, eol_symbol,
                                         test_stream, prefix='test',
                                         every_n_epochs=1,
                                         phoneme_dict=phoneme_dict,
                                         black_list=black_list)
        extensions.append(test_monitor)
        extensions.append(test_ctc_monitor)

    #if not os.path.exists(experiment_path):
    #    os.makedirs(experiment_path)
    #best_path = os.path.join(experiment_path, 'best/')
    #if not os.path.exists(best_path):
    #    os.mkdir(best_path)
    #best_path = os.path.join(best_path, 'model.bin')
    # NOTE(review): checkpointing is disabled — EarlyStopping writes to
    # /dev/null, so the best model is not saved.
    extensions.append(EarlyStopping(to_watch, patience, '/dev/null'))
    extensions.extend([ProgressBar(), Printing()])

    # ------------------------------------------------------------------------
    # Main Loop
    main_loop = MainLoop(model=model, data_stream=train_stream,
                         algorithm=algorithm, extensions=extensions)

    print "Building time: %f" % (time.time() - t0)
    # if write_predictions:
    #     with open('predicted.txt', 'w') as f_pred:
    #         with open('targets.txt', 'w') as f_targets:
    #             evaluator = CTCEvaluator(
    #                 eol_symbol, x, input_mask, y_hat, phoneme_dict,
    #                 black_list)
    #             evaluator.evaluate(dev_stream, file_pred=f_pred,
    #                                file_targets=f_targets)
    #     return
    main_loop.run()
def get_loss(self, model, target, output):
    """ Returns the loss function that can be used by the implementation-
        specific model.

    Returns a pair/list ``(data_bindings, loss)`` mapping data-source
    names to placeholders, plus the loss tensor (Keras) or a loss
    callable (PyTorch). ``variant`` selects warp-CTC ('warp') and
    optional loss scaling ('loss_scale').
    """
    backend = model.get_backend()

    if backend.get_name() == 'keras':

        import keras.backend as K

        if 'warp' in self.variant:
            # Use the external warp-CTC bindings; they are Theano-only.
            logger.info(
                'Attaching Warp-CTC loss function to model '
                'output "%s".', target)
            if backend.get_toolchain() != 'theano':
                logger.error('If you want to use warp-ctc, you need to '
                             'use the Theano backend to Keras.')
                raise ValueError('Warp-CTC is currently only supported '
                                 'with the Theano backend to Keras.')
        else:
            # Just use the built-in Keras CTC loss function.
            logger.debug(
                'Attaching built-in Keras CTC loss function to '
                'model output "%s".', target)

        # Derived data-source names for the scaled utterance lengths and
        # the flattened (warp-ctc style) label sequences.
        ctc_scaled = 'ctc_scaled_{}'.format(self.input_length)
        flattened_labels = 'ctc_flattened_labels_{}'.format(target)
        transcript_length = K.placeholder(ndim=2,
                                          dtype='int32',
                                          name=self.output_length)
        transcript = K.placeholder(
            ndim=2,
            dtype='int32',
            name=flattened_labels if 'warp' in self.variant \
                else self.output
        )
        utterance_length = K.placeholder(
            ndim=2,
            dtype='int32',
            name=self.input_length if self.relative_to is None \
                else ctc_scaled
        )

        if self.relative_to is not None:
            model.add_data_source(
                ctc_scaled,
                ScaledSource(model,
                             relative_to=self.relative_to,
                             to_this=target,
                             scale_this=self.input_length))

        if 'warp' in self.variant:
            model.add_data_source(
                flattened_labels,
                FlattenSource(self.output, self.output_length))

            try:
                import ctc  # pylint: disable=import-error
            except ImportError:
                logger.error(
                    'The warp-CTC loss function was requested, '
                    'but we cannot find the "ctc" library. See our '
                    'troubleshooting page for helpful tips.')
                raise ImportError(
                    'Cannot find the "ctc" library, which '
                    'is needed when using the "warp" variant of the CTC '
                    'loss function.')

            # warp-ctc wants time-major activations and flat int vectors;
            # labels are shifted by +1 to leave 0 free for the blank.
            out = ctc.cpu_ctc_th(output.dimshuffle((1, 0, 2)),
                                 K.squeeze(utterance_length, -1),
                                 transcript[0] + 1,
                                 K.squeeze(transcript_length, -1))
        else:
            out = K.ctc_batch_cost(transcript, output, utterance_length,
                                   transcript_length)

        if 'loss_scale' in self.variant:
            logger.debug('Loss scaling is active.')
            out = out * K.mean(K.cast(utterance_length, K.dtype(out))) / 100

        return (
            (
                (self.output_length, transcript_length),
                (flattened_labels if 'warp' in self.variant \
                    else self.output, transcript),
                (self.input_length if self.relative_to is None \
                    else ctc_scaled, utterance_length)
            ),
            out
        )

    elif backend.get_name() == 'pytorch':

        if 'warp' not in self.variant:
            logger.error(
                'PyTorch does not include a native CTC loss '
                'function yet. However, PyTorch bindings to Warp CTC are '
                'available (SeanNaren/warp-ctc). Try installing that, and '
                'then setting variant=warp.')
            raise ValueError('Only Warp CTC is supported for PyTorch '
                             'right now.')

        ctc_scaled = 'ctc_scaled_{}'.format(self.input_length)
        flattened_labels = 'ctc_flattened_labels_{}'.format(target)
        transcript_length = model.data.placeholder(self.output_length,
                                                   location='cpu',
                                                   data_type='int')
        transcript = model.data.placeholder(flattened_labels,
                                            location='cpu',
                                            data_type='int')
        utterance_length = model.data.placeholder(
            self.input_length if self.relative_to is None else ctc_scaled,
            location='cpu',
            data_type='int')

        if self.relative_to is not None:
            model.add_data_source(
                ctc_scaled,
                ScaledSource(model,
                             relative_to=self.relative_to,
                             to_this=target,
                             scale_this=self.input_length))

        if 'warp' in self.variant:
            model.add_data_source(
                flattened_labels,
                FlattenSource(self.output, self.output_length))

        try:
            from warpctc_pytorch import CTCLoss  # pylint: disable=import-error
        except ImportError:
            logger.error(
                'The warp-CTC loss function was requested, '
                'but we cannot find the "warpctc_pytorch" library. See '
                'our troubleshooting page for helpful tips.')
            raise ImportError(
                'Cannot find the "warpctc_pytorch" library, '
                'which is needed when using the "warp" variant of the CTC '
                'loss function.')

        loss = model.data.move(CTCLoss())

        def basic_ctc_loss(inputs, output):
            """ Computes CTC loss, averaged over the batch dimension.
            """
            return loss(
                output.transpose(1, 0).contiguous(),
                inputs[0][0] + 1,       # transcript[0]+1
                inputs[1].squeeze(1),   # K.squeeze(utterance_length, -1),
                inputs[2].squeeze(1)    # K.squeeze(transcript_length, -1)
            ) / output.size(0)

        if 'loss_scale' in self.variant:
            logger.debug('Loss scaling is active.')

            def loss_scale(inputs, output):
                """ Computes CTC loss, scaled by mean utterance length.
                """
                factor = inputs[1].float().mean().data[0] / 100.
                return basic_ctc_loss(inputs, output) * factor
            get_ctc_loss = loss_scale
        else:
            get_ctc_loss = basic_ctc_loss

        return [
            [
                (flattened_labels if 'warp' in self.variant \
                    else self.output, transcript),
                (self.input_length if self.relative_to is None \
                    else ctc_scaled, utterance_length),
                (self.output_length, transcript_length)
            ],
            get_ctc_loss
        ]

    else:
        raise ValueError(
            'Unsupported backend "{}" for loss function "{}"'.format(
                backend.get_name(), self.get_name()))
def get_loss(self, model, target, output):
    """ Returns the loss function that can be used by the implementation-
        specific model.

    Returns a pair/list ``(data_bindings, loss)`` mapping data-source
    names to placeholders, plus the loss tensor (Keras) or a loss
    callable (PyTorch). ``variant`` selects warp-CTC ('warp') and
    optional loss scaling ('loss_scale').
    """
    backend = model.get_backend()

    if backend.get_name() == 'keras':

        import keras.backend as K

        if 'warp' in self.variant:
            # Use the external warp-CTC bindings; they are Theano-only.
            logger.info('Attaching Warp-CTC loss function to model '
                'output "%s".', target)
            if backend.get_toolchain() != 'theano':
                logger.error('If you want to use warp-ctc, you need to '
                    'use the Theano backend to Keras.')
                raise ValueError('Warp-CTC is currently only supported '
                    'with the Theano backend to Keras.')
        else:
            # Just use the built-in Keras CTC loss function.
            logger.debug('Attaching built-in Keras CTC loss function to '
                'model output "%s".', target)

        # Derived data-source names for the scaled utterance lengths and
        # the flattened (warp-ctc style) label sequences.
        ctc_scaled = 'ctc_scaled_{}'.format(self.input_length)
        flattened_labels = 'ctc_flattened_labels_{}'.format(target)
        transcript_length = K.placeholder(
            ndim=2,
            dtype='int32',
            name=self.output_length
        )
        transcript = K.placeholder(
            ndim=2,
            dtype='int32',
            name=flattened_labels if 'warp' in self.variant \
                else self.output
        )
        utterance_length = K.placeholder(
            ndim=2,
            dtype='int32',
            name=self.input_length if self.relative_to is None \
                else ctc_scaled
        )

        if self.relative_to is not None:
            model.add_data_source(
                ctc_scaled,
                ScaledSource(
                    model,
                    relative_to=self.relative_to,
                    to_this=target,
                    scale_this=self.input_length
                )
            )

        if 'warp' in self.variant:
            model.add_data_source(
                flattened_labels,
                FlattenSource(
                    self.output,
                    self.output_length
                )
            )

            try:
                import ctc  # pylint: disable=import-error
            except ImportError:
                logger.error('The warp-CTC loss function was requested, '
                    'but we cannot find the "ctc" library. See our '
                    'troubleshooting page for helpful tips.')
                raise ImportError('Cannot find the "ctc" library, which '
                    'is needed when using the "warp" variant of the CTC '
                    'loss function.')

            # warp-ctc wants time-major activations and flat int vectors;
            # labels are shifted by +1 to leave 0 free for the blank.
            out = ctc.cpu_ctc_th(
                output.dimshuffle((1, 0, 2)),
                K.squeeze(utterance_length, -1),
                transcript[0]+1,
                K.squeeze(transcript_length, -1)
            )
        else:
            out = K.ctc_batch_cost(
                transcript,
                output,
                utterance_length,
                transcript_length
            )

        if 'loss_scale' in self.variant:
            logger.debug('Loss scaling is active.')
            out = out * K.mean(
                K.cast(utterance_length, K.dtype(out))
            ) / 100

        return (
            (
                (self.output_length, transcript_length),
                (flattened_labels if 'warp' in self.variant \
                    else self.output, transcript),
                (self.input_length if self.relative_to is None \
                    else ctc_scaled, utterance_length)
            ),
            out
        )

    elif backend.get_name() == 'pytorch':

        if 'warp' not in self.variant:
            logger.error('PyTorch does not include a native CTC loss '
                'function yet. However, PyTorch bindings to Warp CTC are '
                'available (SeanNaren/warp-ctc). Try installing that, and '
                'then setting variant=warp.')
            raise ValueError('Only Warp CTC is supported for PyTorch '
                'right now.')

        ctc_scaled = 'ctc_scaled_{}'.format(self.input_length)
        flattened_labels = 'ctc_flattened_labels_{}'.format(target)
        transcript_length = model.data.placeholder(
            self.output_length,
            location='cpu',
            data_type='int'
        )
        transcript = model.data.placeholder(
            flattened_labels,
            location='cpu',
            data_type='int'
        )
        utterance_length = model.data.placeholder(
            self.input_length if self.relative_to is None else ctc_scaled,
            location='cpu',
            data_type='int'
        )

        if self.relative_to is not None:
            model.add_data_source(
                ctc_scaled,
                ScaledSource(
                    model,
                    relative_to=self.relative_to,
                    to_this=target,
                    scale_this=self.input_length
                )
            )

        if 'warp' in self.variant:
            model.add_data_source(
                flattened_labels,
                FlattenSource(
                    self.output,
                    self.output_length
                )
            )

        try:
            from warpctc_pytorch import CTCLoss  # pylint: disable=import-error
        except ImportError:
            logger.error('The warp-CTC loss function was requested, '
                'but we cannot find the "warpctc_pytorch" library. See '
                'our troubleshooting page for helpful tips.')
            raise ImportError('Cannot find the "warpctc_pytorch" library, '
                'which is needed when using the "warp" variant of the CTC '
                'loss function.')

        loss = model.data.move(CTCLoss())

        def basic_ctc_loss(inputs, output):
            """ Computes CTC loss, averaged over the batch dimension.
            """
            return loss(
                output.transpose(1, 0).contiguous(),
                inputs[0][0]+1,         # transcript[0]+1
                inputs[1].squeeze(1),   # K.squeeze(utterance_length, -1),
                inputs[2].squeeze(1)    # K.squeeze(transcript_length, -1)
            ) / output.size(0)

        if 'loss_scale' in self.variant:
            logger.debug('Loss scaling is active.')

            def loss_scale(inputs, output):
                """ Computes CTC loss, scaled by mean utterance length.
                """
                factor = inputs[1].float().mean().data[0] / 100.
                return basic_ctc_loss(inputs, output) * factor
            get_ctc_loss = loss_scale
        else:
            get_ctc_loss = basic_ctc_loss

        return [
            [
                (flattened_labels if 'warp' in self.variant \
                    else self.output, transcript),
                (self.input_length if self.relative_to is None \
                    else ctc_scaled, utterance_length),
                (self.output_length, transcript_length)
            ],
            get_ctc_loss
        ]

    else:
        raise ValueError('Unsupported backend "{}" for loss function "{}"'
            .format(backend.get_name(), self.get_name()))