def test_ctc_greedy_decode(self):
    """Greedy CTC decode: verify behavior with and without de-duplication."""
    # unique=False keeps repeated non-blank labels in the decoded output.
    decoded = py_ctc.ctc_greedy_decode(self.model_output, 0, unique=False)
    self.assertEqual(decoded, [[2, 2], [3, 1]])
    # unique=True collapses consecutive duplicate labels.
    decoded = py_ctc.ctc_greedy_decode(self.model_output, 0, unique=True)
    self.assertEqual(decoded, [[2], [3, 1]])
def eval(self):
    """Evaluation-only entry point: decode the eval set and log token error."""
    mode = utils.EVAL

    # The dataset must be initialized before the model graph is built.
    eval_ds, eval_task = self.input_data(mode=mode)
    eval_gen = tf.data.make_one_shot_iterator(eval_ds)

    # Build the model in eval mode and confirm construction succeeded.
    self.model_fn(mode=mode)
    assert self._built

    eval_func = self.get_metric_func()

    targets_all, predicts_all = [], []
    for _ in range(len(eval_task)):
        # Pull one batch from the one-shot iterator via the Keras session.
        batch = K.get_session().run(eval_gen.get_next()[0])
        feats = batch['inputs']
        labels = batch['targets'].tolist()
        # Forward pass, then greedy CTC decode (collapsing duplicates).
        logits = eval_func(feats)[0]
        hyps = py_ctc.ctc_greedy_decode(logits, 0, unique=True)
        targets_all.extend(labels)
        predicts_all.extend(hyps)

    token_errors = metrics_lib.token_error(
        predict_seq_list=predicts_all,
        target_seq_list=targets_all,
        eos_id=0)
    logging.info("eval finish!")
    logging.info("Token Error: {}".format(token_errors))
def infer(self, yield_single_examples=False):
    """Inference-only entry point: decode each batch and log per-utterance results.

    Args:
      yield_single_examples: kept for interface compatibility; not used here.
    """
    mode = utils.INFER

    # The dataset must be initialized before the model graph is built.
    infer_ds, infer_task = self.input_data(mode=mode)
    infer_gen = tf.data.make_one_shot_iterator(infer_ds)

    # Build the model in infer mode and confirm construction succeeded.
    self.model_fn(mode=mode)
    assert self._built

    infer_func = self.get_metric_func()

    for _ in range(len(infer_task)):
        # Pull one batch from the one-shot iterator via the Keras session.
        batch = K.get_session().run(infer_gen.get_next()[0])
        feats = batch['inputs']
        uttids = batch['uttids'].tolist()
        # Forward pass, then greedy CTC decode (collapsing duplicates).
        logits = infer_func(feats)[0]
        hyps = py_ctc.ctc_greedy_decode(logits, 0, unique=True)
        for idx, uttid in enumerate(uttids):
            logging.info("utt ID: {}".format(uttid))
            logging.info("infer result: {}".format(hyps[idx]))
def on_epoch_end(self, epoch, logs=None):
    """Compute token error on the eval set at the end of each Keras epoch.

    Args:
      epoch: zero-based epoch index supplied by Keras.
      logs: Keras metric dict; 'val_token_err' is written into it in place.
    """
    # BUG FIX: the default was the mutable literal `{}`, which is shared
    # across calls and could accumulate state; use None and create a fresh
    # dict per invocation instead.
    if logs is None:
        logs = {}

    cur_session = K.get_session()
    target_seq_list, predict_seq_list = [], []

    # Decide how batches are fetched: tf.data Dataset, raw TF Iterator, or a
    # python Sequence-style dataset indexable by batch index.
    is_py_sequence = True
    if isinstance(self.eval_ds, (dataset_ops.DatasetV2, dataset_ops.DatasetV1)):
        eval_gen = self.eval_ds.make_one_shot_iterator()
        self.next_batch_gen = eval_gen.get_next()[0]
        is_py_sequence = False
    elif isinstance(self.eval_ds,
                    (iterator_ops.IteratorV2, iterator_ops.Iterator)):
        # BUG FIX: this branch previously called `self.ds.get_next()`, but
        # the isinstance test above is on `self.eval_ds`; use the attribute
        # that was actually checked.
        self.next_batch_gen = self.eval_ds.get_next()[0]
        is_py_sequence = False

    for index in range(len(self.eval_task)):
        if is_py_sequence:
            batch_data = self.eval_ds[index][0]
        else:
            batch_data = cur_session.run(self.next_batch_gen)
        batch_input = batch_data['inputs']
        batch_target = batch_data['targets'].tolist()
        batch_predict = self.func(batch_input)[0]

        if self.decoder_type == 'argmax':
            # Greedy (best-path) decode, collapsing duplicate labels.
            predict_seq_list += py_ctc.ctc_greedy_decode(
                batch_predict, 0, unique=True)
        else:
            # Beam-search decode; keep only the top path ([0]).
            sequence_lens = [
                len(pre_sequence) for pre_sequence in batch_predict
            ]
            batch_decoder, _ = tf_ctc.ctc_beam_search_decode(
                tf.constant(batch_predict),
                tf.constant(sequence_lens),
                beam_width=3,
                top_paths=3)
            predict_seq_list += cur_session.run(batch_decoder)[0].tolist()
        target_seq_list += batch_target

    val_token_errors = metrics_lib.token_error(
        predict_seq_list=predict_seq_list,
        target_seq_list=target_seq_list,
        eos_id=0)
    logs['val_token_err'] = val_token_errors

    if 'val_loss' in logs:
        logging.info("Epoch {}: on eval, val_loss is {}.".format(
            epoch + 1, logs['val_loss']))
    logging.info("Epoch {}: on eval, token_err is {}.".format(
        epoch + 1, val_token_errors))
    # ROBUSTNESS: guard like 'val_loss' above — 'loss' is absent when the
    # callback runs outside a training fit() loop, which previously raised
    # KeyError.
    if 'loss' in logs:
        logging.info("Epoch {}: loss on train is {}".format(
            epoch + 1, logs['loss']))