Example #1
    def check_loading(self, num_gpu, is_sorted):
        print('----- num_gpu: ' + str(num_gpu) + ', is_sorted: ' +
              str(is_sorted) + ' -----')

        batch_size = 64
        dataset = Dataset(data_type='train',
                          label_type_main='character',
                          label_type_sub='phone61',
                          batch_size=batch_size,
                          num_stack=3,
                          num_skip=3,
                          is_sorted=is_sorted,
                          is_progressbar=True,
                          num_gpu=num_gpu)

        tf.reset_default_graph()
        with tf.Session().as_default() as sess:
            print('=> Loading mini-batch...')
            map_file_path_char = '../metrics/mapping_files/ctc/char2num.txt'
            map_file_path_phone = '../metrics/mapping_files/ctc/phone2num_61.txt'

            mini_batch = dataset.next_batch(session=sess)

            iter_per_epoch = int(dataset.data_num / (batch_size * num_gpu)) + 1
            for i in range(iter_per_epoch + 1):
                return_tuple = mini_batch.__next__()
                inputs = return_tuple[0]
                labels_char_st = return_tuple[1]
                labels_phone_st = return_tuple[2]

                if num_gpu > 1:
                    for inputs_gpu in inputs:
                        print(inputs_gpu.shape)
                    labels_char_st = labels_char_st[0]
                    labels_phone_st = labels_phone_st[0]

                labels_char = sparsetensor2list(labels_char_st,
                                                batch_size=len(inputs))
                labels_phone = sparsetensor2list(labels_phone_st,
                                                 batch_size=len(inputs))

                if num_gpu == 1:
                    for inputs_i, labels_i in zip(inputs, labels_char):
                        if len(inputs_i) < len(labels_i):
                            print(len(inputs_i))
                            print(len(labels_i))
                            raise ValueError
                    for inputs_i, labels_i in zip(inputs, labels_phone):
                        if len(inputs_i) < len(labels_i):
                            print(len(inputs_i))
                            print(len(labels_i))
                            raise ValueError

                str_true_char = num2char(labels_char[0], map_file_path_char)
                str_true_char = re.sub(r'_', ' ', str_true_char)
                str_true_phone = num2phone(labels_phone[0],
                                           map_file_path_phone)
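
The loader returns labels as SparseTensor values, so the test converts them back to Python lists with sparsetensor2list before mapping indices to symbols via num2char / num2phone. A minimal sketch of such converter helpers (assumed implementations for illustration, not the repository's own code):

import numpy as np


def list2sparsetensor_sketch(labels, padded_value=-1):
    """Convert padded label sequences to (indices, values, shape) for a tf.SparseTensor."""
    indices, values = [], []
    for b, label_seq in enumerate(labels):
        for t, label in enumerate(label_seq):
            if label == padded_value:
                break
            indices.append([b, t])
            values.append(label)
    shape = [len(labels), max(len(label_seq) for label_seq in labels)]
    return (np.array(indices, dtype=np.int64),
            np.array(values, dtype=np.int32),
            np.array(shape, dtype=np.int64))


def sparsetensor2list_sketch(labels_st, batch_size):
    """Re-group a SparseTensorValue (indices, values, shape) into per-utterance label lists."""
    labels = [[] for _ in range(batch_size)]
    for (b, _), value in zip(labels_st.indices, labels_st.values):
        labels[int(b)].append(int(value))
    return labels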
Example #2
def do_plot(model, params, epoch, eval_batch_size):
    """Plot the multi-task CTC posteriors.
    Args:
        model: the model to restore
        params (dict): A dictionary of parameters
        epoch (int): the epoch to restore
        eval_batch_size (int): the size of the mini-batch used in evaluation
    """
    # Load dataset
    test_data = Dataset(data_type='test',
                        label_type_main=params['label_type_main'],
                        label_type_sub=params['label_type_sub'],
                        batch_size=eval_batch_size,
                        splice=params['splice'],
                        num_stack=params['num_stack'],
                        num_skip=params['num_skip'],
                        sort_utt=False,
                        progressbar=True)

    # Define placeholders
    model.create_placeholders()

    # Add to the graph each operation (including model definition)
    _, logits_main, logits_sub = model.compute_loss(
        model.inputs_pl_list[0], model.labels_pl_list[0],
        model.labels_sub_pl_list[0], model.inputs_seq_len_pl_list[0],
        model.keep_prob_pl_list[0])
    posteriors_op_main, posteriors_op_sub = model.posteriors(
        logits_main, logits_sub)

    # Create a saver for writing training checkpoints
    saver = tf.train.Saver()

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(model.save_path)

        # If a checkpoint exists
        if ckpt:
            # Use last saved model
            model_path = ckpt.model_checkpoint_path
            if epoch != -1:
                model_path = model_path.split('/')[:-1]
                model_path = '/'.join(model_path) + '/model.ckpt-' + str(epoch)
            saver.restore(sess, model_path)
            print("Model restored: " + model_path)
        else:
            raise ValueError('There are not any checkpoints.')

        plot(session=sess,
             posteriors_op_main=posteriors_op_main,
             posteriors_op_sub=posteriors_op_sub,
             model=model,
             dataset=test_data,
             label_type_main=params['label_type_main'],
             label_type_sub=params['label_type_sub'],
             num_stack=params['num_stack'],
             save_path=mkdir_join(save_path, 'ctc_output'),
             show=False)
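
model.posteriors is assumed to turn the CTC logits of both tasks into frame-level label posteriors. A minimal sketch, assuming time-major logits of shape [max_time, batch_size, num_classes] as expected by TensorFlow's CTC ops:

import tensorflow as tf


def posteriors_sketch(logits_main, logits_sub):
    """Frame-level softmax over the logits of the main and sub tasks (sketch only)."""
    # Softmax over the last (class) dimension yields per-frame posterior probabilities
    posteriors_main = tf.nn.softmax(logits_main)
    posteriors_sub = tf.nn.softmax(logits_sub)
    return posteriors_main, posteriors_sub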
Example #3
    def check(self,
              label_type_main,
              data_type='dev',
              shuffle=False,
              sort_utt=False,
              sort_stop_epoch=None,
              frame_stacking=False,
              splice=1):

        print('========================================')
        print('  label_type_main: %s' % label_type_main)
        print('  data_type: %s' % data_type)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('========================================')

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(data_type=data_type,
                          label_type_main=label_type_main,
                          label_type_sub='phone61',
                          batch_size=64,
                          max_epoch=2,
                          splice=splice,
                          num_stack=num_stack,
                          num_skip=num_skip,
                          shuffle=shuffle,
                          sort_utt=sort_utt,
                          sort_stop_epoch=sort_stop_epoch,
                          progressbar=True)

        print('=> Loading mini-batch...')
        idx2char = Idx2char(map_file_path='../../metrics/mapping_files/' +
                            label_type_main + '.txt')
        idx2phone = Idx2phone(
            map_file_path='../../metrics/mapping_files/phone61.txt')

        for data, is_new_epoch in dataset:
            inputs, labels_char, labels_phone, inputs_seq_len, input_names = data

            if data_type != 'test':
                str_true_char = idx2char(labels_char[0][0])
                str_true_phone = idx2phone(labels_phone[0][0])
            else:
                str_true_char = labels_char[0][0][0]
                str_true_phone = labels_phone[0][0][0]

            print('----- %s ----- (epoch: %.3f)' %
                  (input_names[0][0], dataset.epoch_detail))
            print(str_true_char)
            print(str_true_phone)
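
The loop "for data, is_new_epoch in dataset" implies that Dataset is an iterator yielding a batch tuple together with an end-of-epoch flag and exposing an epoch_detail property. A stripped-down sketch of that protocol (illustrative only, not the repository's Dataset):

class EpochIteratorSketch(object):
    """Minimal loader that yields (batch, is_new_epoch) and tracks fractional epochs."""

    def __init__(self, num_utterances, batch_size, max_epoch):
        self.num_utterances = num_utterances
        self.batch_size = batch_size
        self.max_epoch = max_epoch
        self.offset = 0
        self.epoch = 0

    @property
    def epoch_detail(self):
        # Fractional progress through training, e.g. 1.25 epochs
        return self.epoch + self.offset / self.num_utterances

    def __iter__(self):
        return self

    def __next__(self):
        if self.epoch >= self.max_epoch:
            raise StopIteration
        batch = list(range(self.offset,
                           min(self.offset + self.batch_size, self.num_utterances)))
        self.offset += self.batch_size
        is_new_epoch = self.offset >= self.num_utterances
        if is_new_epoch:
            self.offset = 0
            self.epoch += 1
        return batch, is_new_epoch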
Example #4
def do_eval(network, param, epoch=None):
    """Evaluate the model.
    Args:
        network: model to restore
        param: A dictionary of parameters
        epoch: int, the epoch to restore
    """
    # Load dataset
    test_data = Dataset(data_type='test',
                        label_type_main='character',
                        label_type_sub='phone39',
                        batch_size=1,
                        num_stack=param['num_stack'],
                        num_skip=param['num_skip'],
                        is_sorted=False,
                        is_progressbar=True)

    # Define placeholders
    network.inputs = tf.placeholder(tf.float32,
                                    shape=[None, None, network.input_size],
                                    name='input')
    indices_pl = tf.placeholder(tf.int64, name='indices')
    values_pl = tf.placeholder(tf.int32, name='values')
    shape_pl = tf.placeholder(tf.int64, name='shape')
    network.labels = tf.SparseTensor(indices_pl, values_pl, shape_pl)
    indices_sub_pl = tf.placeholder(tf.int64, name='indices_sub')
    values_sub_pl = tf.placeholder(tf.int32, name='values_sub')
    shape_sub_pl = tf.placeholder(tf.int64, name='shape_sub')
    network.labels_sub = tf.SparseTensor(indices_sub_pl, values_sub_pl,
                                         shape_sub_pl)
    network.inputs_seq_len = tf.placeholder(tf.int64,
                                            shape=[None],
                                            name='inputs_seq_len')

    # Add to the graph each operation
    _, logits_main, logits_sub = network.compute_loss(network.inputs,
                                                      network.labels,
                                                      network.labels_sub,
                                                      network.inputs_seq_len)
    decode_op_main, decode_op_sub = network.decoder(logits_main,
                                                    logits_sub,
                                                    network.inputs_seq_len,
                                                    decode_type='beam_search',
                                                    beam_width=20)
    per_op_main, per_op_sub = network.compute_ler(decode_op_main,
                                                  decode_op_sub,
                                                  network.labels,
                                                  network.labels_sub)

    # Create a saver for writing training checkpoints
    saver = tf.train.Saver()

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(network.model_dir)

        # If a checkpoint exists
        if ckpt:
            # Use last saved model
            model_path = ckpt.model_checkpoint_path
            if epoch is not None:
                model_path = model_path.split('/')[:-1]
                model_path = '/'.join(model_path) + '/model.ckpt-' + str(epoch)
            saver.restore(sess, model_path)
            print("Model restored: " + model_path)
        else:
            raise ValueError('There are not any checkpoints.')

        print('=== Test Data Evaluation ===')
        cer_test = do_eval_cer(session=sess,
                               decode_op=decode_op_main,
                               network=network,
                               dataset=test_data,
                               is_progressbar=True,
                               is_multitask=True)
        print('  CER: %f %%' % (cer_test * 100))

        per_test = do_eval_per(session=sess,
                               decode_op=decode_op_sub,
                               per_op=per_op_sub,
                               network=network,
                               dataset=test_data,
                               train_label_type=param['label_type_sub'],
                               is_progressbar=True,
                               is_multitask=True)
        print('  PER: %f %%' % (per_test * 100))
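
The epoch-specific restore above rebuilds the checkpoint path by splitting on '/'. An equivalent, slightly more portable variant using os.path (a sketch, not the original helper):

import os


def select_checkpoint_sketch(ckpt, epoch=None):
    """Return the last checkpoint path, or 'model.ckpt-<epoch>' in the same directory."""
    model_path = ckpt.model_checkpoint_path
    if epoch is not None:
        model_path = os.path.join(os.path.dirname(model_path),
                                  'model.ckpt-' + str(epoch))
    return model_path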
Example #5
def do_decode(model, params, epoch, beam_width, eval_batch_size):
    """Decode the Multi-task CTC outputs.
    Args:
        model: the model to restore
        params (dict): A dictionary of parameters
        epoch (int): the epoch to restore
        beam_width (int): beam width for beam search.
            1 disables beam search, which means greedy decoding.
        eval_batch_size (int): the size of the mini-batch used in evaluation
    """
    # Load dataset
    test_data = Dataset(data_type='test',
                        label_type_main=params['label_type_main'],
                        label_type_sub=params['label_type_sub'],
                        batch_size=eval_batch_size,
                        splice=params['splice'],
                        num_stack=params['num_stack'],
                        num_skip=params['num_skip'],
                        shuffle=False,
                        progressbar=True)

    # Define placeholders
    model.create_placeholders()

    # Add to the graph each operation (including model definition)
    _, logits_main, logits_sub = model.compute_loss(
        model.inputs_pl_list[0], model.labels_pl_list[0],
        model.labels_sub_pl_list[0], model.inputs_seq_len_pl_list[0],
        model.keep_prob_pl_list[0])
    decode_op_main, decode_op_sub = model.decoder(
        logits_main,
        logits_sub,
        model.inputs_seq_len_pl_list[0],
        beam_width=beam_width)

    # Create a saver for writing training checkpoints
    saver = tf.train.Saver()

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(model.save_path)

        # If a checkpoint exists
        if ckpt:
            # Use last saved model
            model_path = ckpt.model_checkpoint_path
            if epoch != -1:
                model_path = model_path.split('/')[:-1]
                model_path = '/'.join(model_path) + '/model.ckpt-' + str(epoch)
            saver.restore(sess, model_path)
            print("Model restored: " + model_path)
        else:
            raise ValueError('There are not any checkpoints.')

        # Visualize
        decode(session=sess,
               decode_op_main=decode_op_main,
               decode_op_sub=decode_op_sub,
               model=model,
               dataset=test_data,
               label_type_main=params['label_type_main'],
               label_type_sub=params['label_type_sub'],
               is_test=True,
               save_path=None)
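
model.decoder with beam_width=1 corresponds to greedy decoding. A sketch of how such a decoder could be built on TensorFlow's CTC ops, assuming time-major logits of shape [max_time, batch_size, num_classes] (an assumption, not the repository's exact implementation):

import tensorflow as tf


def ctc_decoder_sketch(logits_main, logits_sub, inputs_seq_len, beam_width=1):
    """Greedy decoding for beam_width == 1, beam search otherwise (sketch)."""
    seq_len = tf.cast(inputs_seq_len, tf.int32)
    if beam_width == 1:
        decoded_main, _ = tf.nn.ctc_greedy_decoder(logits_main, seq_len)
        decoded_sub, _ = tf.nn.ctc_greedy_decoder(logits_sub, seq_len)
    else:
        decoded_main, _ = tf.nn.ctc_beam_search_decoder(
            logits_main, seq_len, beam_width=beam_width)
        decoded_sub, _ = tf.nn.ctc_beam_search_decoder(
            logits_sub, seq_len, beam_width=beam_width)
    # Both decoders return a list of SparseTensors; keep the best path only.
    return decoded_main[0], decoded_sub[0]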
Example #6
def do_train(network, param):
    """Run multi-task training. The target labels in the main task is
    characters and those in the sub task is 61 phones. The model is
    evaluated by CER and PER with 39 phones.
    Args:
        network: network to train
        param: A dictionary of parameters
    """
    # Load dataset
    train_data = Dataset(data_type='train',
                         label_type_main='character',
                         label_type_sub=param['label_type_sub'],
                         batch_size=param['batch_size'],
                         num_stack=param['num_stack'],
                         num_skip=param['num_skip'],
                         is_sorted=True)
    dev_data = Dataset(data_type='dev',
                       label_type_main='character',
                       label_type_sub=param['label_type_sub'],
                       batch_size=param['batch_size'],
                       num_stack=param['num_stack'],
                       num_skip=param['num_skip'],
                       is_sorted=False)
    test_data = Dataset(data_type='test',
                        label_type_main='character',
                        label_type_sub='phone39',
                        batch_size=1,
                        num_stack=param['num_stack'],
                        num_skip=param['num_skip'],
                        is_sorted=False)

    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default():

        # Define placeholders
        network.inputs = tf.placeholder(tf.float32,
                                        shape=[None, None, network.input_size],
                                        name='input')
        indices_pl = tf.placeholder(tf.int64, name='indices')
        values_pl = tf.placeholder(tf.int32, name='values')
        shape_pl = tf.placeholder(tf.int64, name='shape')
        network.labels = tf.SparseTensor(indices_pl, values_pl, shape_pl)
        indices_sub_pl = tf.placeholder(tf.int64, name='indices_sub')
        values_sub_pl = tf.placeholder(tf.int32, name='values_sub')
        shape_sub_pl = tf.placeholder(tf.int64, name='shape_sub')
        network.labels_sub = tf.SparseTensor(indices_sub_pl, values_sub_pl,
                                             shape_sub_pl)
        network.inputs_seq_len = tf.placeholder(tf.int64,
                                                shape=[None],
                                                name='inputs_seq_len')
        network.keep_prob_input = tf.placeholder(tf.float32,
                                                 name='keep_prob_input')
        network.keep_prob_hidden = tf.placeholder(tf.float32,
                                                  name='keep_prob_hidden')

        # Add to the graph each operation
        loss_op, logits_main, logits_sub = network.compute_loss(
            network.inputs, network.labels, network.labels_sub,
            network.inputs_seq_len, network.keep_prob_input,
            network.keep_prob_hidden)
        train_op = network.train(loss_op,
                                 optimizer=param['optimizer'],
                                 learning_rate_init=float(
                                     param['learning_rate']),
                                 decay_steps=param['decay_steps'],
                                 decay_rate=param['decay_rate'])
        decode_op_main, decode_op_sub = network.decoder(
            logits_main,
            logits_sub,
            network.inputs_seq_len,
            decode_type='beam_search',
            beam_width=20)
        ler_op_main, ler_op_sub = network.compute_ler(decode_op_main,
                                                      decode_op_sub,
                                                      network.labels,
                                                      network.labels_sub)

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(network.summaries_train)
        summary_dev = tf.summary.merge(network.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()), "{:,}".format(
                  total_parameters / 1000000)))

        csv_steps, csv_loss_train, csv_loss_dev = [], [], []
        csv_cer_train, csv_cer_dev = [], []
        csv_per_train, csv_per_dev = [], []
        # Create a session for running operations on the graph
        with tf.Session() as sess:

            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(network.model_dir,
                                                   sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Make mini-batch generator
            mini_batch_train = train_data.next_batch()
            mini_batch_dev = dev_data.next_batch()

            # Train model
            iter_per_epoch = int(train_data.data_num / param['batch_size'])
            train_step = train_data.data_num / param['batch_size']
            if train_step != int(train_step):
                iter_per_epoch += 1
            max_steps = iter_per_epoch * param['num_epoch']
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            cer_dev_best = 1
            for step in range(max_steps):

                # Create feed dictionary for next mini batch (train)
                with tf.device('/cpu:0'):
                    inputs, labels_char, labels_phone, inputs_seq_len, _ = \
                        next(mini_batch_train)
                feed_dict_train = {
                    network.inputs:
                    inputs,
                    network.labels:
                    list2sparsetensor(labels_char, padded_value=-1),
                    network.labels_sub:
                    list2sparsetensor(labels_phone, padded_value=-1),
                    network.inputs_seq_len:
                    inputs_seq_len,
                    network.keep_prob_input:
                    network.dropout_ratio_input,
                    network.keep_prob_hidden:
                    network.dropout_ratio_hidden
                }

                # Update parameters
                sess.run(train_op, feed_dict=feed_dict_train)

                if (step + 1) % 10 == 0:

                    # Create feed dictionary for next mini batch (dev)
                    with tf.device('/cpu:0'):
                        inputs, labels_char, labels_phone, inputs_seq_len, _ = \
                            next(mini_batch_dev)
                    feed_dict_dev = {
                        network.inputs:
                        inputs,
                        network.labels:
                        list2sparsetensor(labels_char, padded_value=-1),
                        network.labels_sub:
                        list2sparsetensor(labels_phone, padded_value=-1),
                        network.inputs_seq_len:
                        inputs_seq_len,
                        network.keep_prob_input:
                        network.dropout_ratio_input,
                        network.keep_prob_hidden:
                        network.dropout_ratio_hidden
                    }

                    # Compute loss
                    loss_train = sess.run(loss_op, feed_dict=feed_dict_train)
                    loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                    csv_steps.append(step)
                    csv_loss_train.append(loss_train)
                    csv_loss_dev.append(loss_dev)

                    # Change to evaluation mode
                    feed_dict_train[network.keep_prob_input] = 1.0
                    feed_dict_train[network.keep_prob_hidden] = 1.0
                    feed_dict_dev[network.keep_prob_input] = 1.0
                    feed_dict_dev[network.keep_prob_hidden] = 1.0

                    # Compute accuracy & update event file
                    cer_train, per_train, summary_str_train = sess.run(
                        [ler_op_main, ler_op_sub, summary_train],
                        feed_dict=feed_dict_train)
                    cer_dev, per_dev, summary_str_dev = sess.run(
                        [ler_op_main, ler_op_sub, summary_dev],
                        feed_dict=feed_dict_dev)
                    csv_cer_train.append(cer_train)
                    csv_cer_dev.append(cer_dev)
                    csv_per_train.append(per_train)
                    csv_per_dev.append(per_dev)
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    duration_step = time.time() - start_time_step
                    print(
                        "Step %d: loss = %.3f (%.3f) / cer = %.4f (%.4f) / per = %.4f (%.4f) (%.3f min)"
                        % (step + 1, loss_train, loss_dev, cer_train, cer_dev,
                           per_train, per_dev, duration_step / 60))
                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if (step + 1) % iter_per_epoch == 0 or (step + 1) == max_steps:
                    duration_epoch = time.time() - start_time_epoch
                    epoch = (step + 1) // iter_per_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (epoch, duration_epoch / 60))

                    # Save model (check point)
                    checkpoint_file = join(network.model_dir, 'model.ckpt')
                    save_path = saver.save(sess,
                                           checkpoint_file,
                                           global_step=epoch)
                    print("Model saved in file: %s" % save_path)

                    if epoch >= 10:
                        start_time_eval = time.time()
                        print('=== Dev Data Evaluation ===')
                        cer_dev_epoch = do_eval_cer(session=sess,
                                                    decode_op=decode_op_main,
                                                    network=network,
                                                    dataset=dev_data,
                                                    eval_batch_size=1,
                                                    is_multitask=True)
                        print('  CER: %f %%' % (cer_dev_epoch * 100))
                        per_dev_epoch = do_eval_per(
                            session=sess,
                            decode_op=decode_op_sub,
                            per_op=ler_op_sub,
                            network=network,
                            dataset=dev_data,
                            label_type=param['label_type_sub'],
                            eval_batch_size=1,
                            is_multitask=True)
                        print('  PER: %f %%' % (per_dev_epoch * 100))

                        if cer_dev_epoch < cer_dev_best:
                            cer_dev_best = cer_dev_epoch
                            print('■■■ ↑Best Score (CER)↑ ■■■')

                            print('=== Test Data Evaluation ===')
                            cer_test = do_eval_cer(session=sess,
                                                   decode_op=decode_op_main,
                                                   network=network,
                                                   dataset=test_data,
                                                   eval_batch_size=1,
                                                   is_multitask=True)
                            print('  CER: %f %%' % (cer_test * 100))
                            per_test = do_eval_per(
                                session=sess,
                                decode_op=decode_op_sub,
                                per_op=ler_op_sub,
                                network=network,
                                dataset=test_data,
                                label_type=param['label_type_sub'],
                                eval_batch_size=1,
                                is_multitask=True)
                            print('  PER: %f %%' % (per_test * 100))

                        duration_eval = time.time() - start_time_eval
                        print('Evaluation time: %.3f min' %
                              (duration_eval / 60))

                        start_time_epoch = time.time()
                        start_time_step = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Save train & dev loss, ler
            save_loss(csv_steps,
                      csv_loss_train,
                      csv_loss_dev,
                      save_path=network.model_dir)
            save_ler(csv_steps,
                     csv_cer_train,
                     csv_cer_dev,
                     save_path=network.model_dir)
            save_ler(csv_steps,
                     csv_per_train,
                     csv_per_dev,
                     save_path=network.model_dir)

            # Training was finished correctly
            with open(join(network.model_dir, 'complete.txt'), 'w') as f:
                f.write('')
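
count_total_parameters is assumed to sum the element counts of all trainable variables; a minimal sketch of such a helper:

import numpy as np


def count_total_parameters_sketch(trainable_variables):
    """Return ({variable_name: num_elements}, total_num_elements)."""
    parameters_dict = {}
    total_parameters = 0
    for variable in trainable_variables:
        num_elements = int(np.prod(variable.get_shape().as_list()))
        parameters_dict[variable.name] = num_elements
        total_parameters += num_elements
    return parameters_dict, total_parameters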
Example #7
def do_train(model, params):
    """Run multi-task CTC training. The target labels in the main task is
    characters and those in the sub task is 61 phones. The model is
    evaluated by CER and PER with 39 phones.
    Args:
        model: the model to train
        params (dict): A dictionary of parameters
    """
    # Load dataset
    train_data = Dataset(data_type='train',
                         label_type_main=params['label_type_main'],
                         label_type_sub=params['label_type_sub'],
                         batch_size=params['batch_size'],
                         max_epoch=params['num_epoch'],
                         splice=params['splice'],
                         num_stack=params['num_stack'],
                         num_skip=params['num_skip'],
                         sort_utt=True,
                         sort_stop_epoch=params['sort_stop_epoch'])
    dev_data = Dataset(data_type='dev',
                       label_type_main=params['label_type_main'],
                       label_type_sub=params['label_type_sub'],
                       batch_size=params['batch_size'],
                       splice=params['splice'],
                       num_stack=params['num_stack'],
                       num_skip=params['num_skip'],
                       sort_utt=False)
    test_data = Dataset(data_type='test',
                        label_type_main=params['label_type_main'],
                        label_type_sub='phone39',
                        batch_size=1,
                        splice=params['splice'],
                        num_stack=params['num_stack'],
                        num_skip=params['num_skip'],
                        sort_utt=False)

    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default():

        # Define placeholders
        model.create_placeholders()
        learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

        # Add to the graph each operation
        loss_op, logits_main, logits_sub = model.compute_loss(
            model.inputs_pl_list[0], model.labels_pl_list[0],
            model.labels_sub_pl_list[0], model.inputs_seq_len_pl_list[0],
            model.keep_prob_input_pl_list[0],
            model.keep_prob_hidden_pl_list[0],
            model.keep_prob_output_pl_list[0])
        train_op = model.train(loss_op,
                               optimizer=params['optimizer'],
                               learning_rate=learning_rate_pl)
        decode_op_character, decode_op_phone = model.decoder(
            logits_main,
            logits_sub,
            model.inputs_seq_len_pl_list[0],
            beam_width=params['beam_width'])
        cer_op, per_op = model.compute_ler(decode_op_character,
                                           decode_op_phone,
                                           model.labels_pl_list[0],
                                           model.labels_sub_pl_list[0])

        # Define learning rate controller
        lr_controller = Controller(
            learning_rate_init=params['learning_rate'],
            decay_start_epoch=params['decay_start_epoch'],
            decay_rate=params['decay_rate'],
            decay_patient_epoch=params['decay_patient_epoch'],
            lower_better=True)

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(model.summaries_train)
        summary_dev = tf.summary.merge(model.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()), "{:,}".format(
                  total_parameters / 1000000)))

        csv_steps, csv_loss_train, csv_loss_dev = [], [], []
        csv_cer_train, csv_cer_dev = [], []
        csv_per_train, csv_per_dev = [], []
        # Create a session for running operations on the graph
        with tf.Session() as sess:

            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(model.save_path, sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Train model
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            cer_dev_best = 1
            learning_rate = float(params['learning_rate'])
            for step, (data, is_new_epoch) in enumerate(train_data):

                # Create feed dictionary for next mini batch (train)
                inputs, labels_char, labels_phone, inputs_seq_len, _ = data
                feed_dict_train = {
                    model.inputs_pl_list[0]:
                    inputs,
                    model.labels_pl_list[0]:
                    list2sparsetensor(labels_char,
                                      padded_value=train_data.padded_value),
                    model.labels_sub_pl_list[0]:
                    list2sparsetensor(labels_phone,
                                      padded_value=train_data.padded_value),
                    model.inputs_seq_len_pl_list[0]:
                    inputs_seq_len,
                    model.keep_prob_input_pl_list[0]:
                    params['dropout_input'],
                    model.keep_prob_hidden_pl_list[0]:
                    params['dropout_hidden'],
                    model.keep_prob_output_pl_list[0]:
                    params['dropout_output'],
                    learning_rate_pl:
                    learning_rate
                }

                # Update parameters
                sess.run(train_op, feed_dict=feed_dict_train)

                if (step + 1) % params['print_step'] == 0:

                    # Create feed dictionary for next mini batch (dev)
                    (inputs, labels_char, labels_phone, inputs_seq_len,
                     _), _ = dev_data.next()
                    feed_dict_dev = {
                        model.inputs_pl_list[0]:
                        inputs,
                        model.labels_pl_list[0]:
                        list2sparsetensor(labels_char,
                                          padded_value=dev_data.padded_value),
                        model.labels_sub_pl_list[0]:
                        list2sparsetensor(labels_phone,
                                          padded_value=dev_data.padded_value),
                        model.inputs_seq_len_pl_list[0]:
                        inputs_seq_len,
                        model.keep_prob_input_pl_list[0]:
                        1.0,
                        model.keep_prob_hidden_pl_list[0]:
                        1.0,
                        model.keep_prob_output_pl_list[0]:
                        1.0
                    }

                    # Compute loss
                    loss_train = sess.run(loss_op, feed_dict=feed_dict_train)
                    loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                    csv_steps.append(step)
                    csv_loss_train.append(loss_train)
                    csv_loss_dev.append(loss_dev)

                    # Change to evaluation mode
                    feed_dict_train[model.keep_prob_input_pl_list[0]] = 1.0
                    feed_dict_train[model.keep_prob_hidden_pl_list[0]] = 1.0
                    feed_dict_train[model.keep_prob_output_pl_list[0]] = 1.0

                    # Compute accuracy & update event files
                    cer_train, per_train, summary_str_train = sess.run(
                        [cer_op, per_op, summary_train],
                        feed_dict=feed_dict_train)
                    cer_dev, per_dev, summary_str_dev = sess.run(
                        [cer_op, per_op, summary_dev], feed_dict=feed_dict_dev)
                    csv_cer_train.append(cer_train)
                    csv_cer_dev.append(cer_dev)
                    csv_per_train.append(per_train)
                    csv_per_dev.append(per_dev)
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    duration_step = time.time() - start_time_step
                    print(
                        "Step %d (epoch: %.3f): loss = %.3f (%.3f) / cer = %.3f (%.3f) / per = %.3f (%.3f) / lr = %.5f (%.3f min)"
                        % (step + 1, train_data.epoch_detail, loss_train,
                           loss_dev, cer_train, cer_dev, per_train, per_dev,
                           learning_rate, duration_step / 60))
                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if is_new_epoch:
                    duration_epoch = time.time() - start_time_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (train_data.epoch, duration_epoch / 60))

                    # Save figures of loss & LER
                    plot_loss(csv_loss_train,
                              csv_loss_dev,
                              csv_steps,
                              save_path=model.save_path)
                    plot_ler(csv_cer_train,
                             csv_cer_dev,
                             csv_steps,
                             label_type=params['label_type_main'],
                             save_path=model.save_path)
                    plot_ler(csv_per_train,
                             csv_per_dev,
                             csv_steps,
                             label_type=params['label_type_sub'],
                             save_path=model.save_path)

                    if train_data.epoch >= params['eval_start_epoch']:
                        start_time_eval = time.time()
                        print('=== Dev Data Evaluation ===')
                        cer_dev_epoch, wer_dev_epoch = do_eval_cer(
                            session=sess,
                            decode_op=decode_op_character,
                            model=model,
                            dataset=dev_data,
                            label_type=params['label_type_main'],
                            eval_batch_size=1,
                            is_multitask=True)
                        print('  WER: %f %%' % (wer_dev_epoch * 100))
                        print('  CER: %f %%' % (cer_dev_epoch * 100))
                        per_dev_epoch = do_eval_per(
                            session=sess,
                            decode_op=decode_op_phone,
                            per_op=per_op,
                            model=model,
                            dataset=dev_data,
                            label_type=params['label_type_sub'],
                            eval_batch_size=1,
                            is_multitask=True)
                        print('  PER: %f %%' % (per_dev_epoch * 100))

                        if cer_dev_epoch < cer_dev_best:
                            cer_dev_best = cer_dev_epoch
                            print('■■■ ↑Best Score (CER)↑ ■■■')

                            # Save model only when best accuracy is obtained
                            # (check point)
                            checkpoint_file = join(model.save_path,
                                                   'model.ckpt')
                            save_path = saver.save(
                                sess,
                                checkpoint_file,
                                global_step=train_data.epoch)
                            print("Model saved in file: %s" % save_path)

                            print('=== Test Data Evaluation ===')
                            cer_test, wer_test = do_eval_cer(
                                session=sess,
                                decode_op=decode_op_character,
                                model=model,
                                dataset=test_data,
                                label_type=params['label_type_main'],
                                eval_batch_size=1,
                                is_multitask=True)
                            print('  WER: %f %%' % (wer_test * 100))
                            print('  CER: %f %%' % (cer_test * 100))
                            per_test = do_eval_per(
                                session=sess,
                                decode_op=decode_op_phone,
                                per_op=per_op,
                                model=model,
                                dataset=test_data,
                                label_type=params['label_type_sub'],
                                eval_batch_size=1,
                                is_multitask=True)
                            print('  PER: %f %%' % (per_test * 100))

                        duration_eval = time.time() - start_time_eval
                        print('Evaluation time: %.3f min' %
                              (duration_eval / 60))

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=train_data.epoch,
                            value=cer_dev_epoch)

                    start_time_epoch = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Training was finished correctly
            with open(join(model.save_path, 'complete.txt'), 'w') as f:
                f.write('')
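
The learning-rate Controller is assumed to apply patience-based decay: once training passes decay_start_epoch, the rate is multiplied by decay_rate whenever the dev CER has not improved for decay_patient_epoch epochs. A compact sketch of that logic (an illustration, not the original class):

class ControllerSketch(object):
    """Patience-based learning-rate decay on a 'lower is better' metric (sketch)."""

    def __init__(self, learning_rate_init, decay_start_epoch, decay_rate,
                 decay_patient_epoch=1, lower_better=True):
        self.learning_rate_init = learning_rate_init
        self.decay_start_epoch = decay_start_epoch
        self.decay_rate = decay_rate
        self.decay_patient_epoch = decay_patient_epoch
        self.lower_better = lower_better
        self.best_value = float('inf') if lower_better else float('-inf')
        self.not_improved_epoch = 0

    def decay_lr(self, learning_rate, epoch, value):
        improved = value < self.best_value if self.lower_better else value > self.best_value
        if improved:
            self.best_value = value
            self.not_improved_epoch = 0
        else:
            self.not_improved_epoch += 1
        if (epoch >= self.decay_start_epoch and
                self.not_improved_epoch >= self.decay_patient_epoch):
            learning_rate *= self.decay_rate
            self.not_improved_epoch = 0
        return learning_rate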
Example #8
def do_eval(model, params, epoch, batch_size, beam_width):
    """Evaluate the model.
    Args:
        model: the model to restore
        params (dict): A dictionary of parameters
        epoch (int): the epoch to restore
        batch_size (int): the size of the mini-batch used in evaluation
        beam_width (int): beam width for beam search.
            1 disables beam search, which means greedy decoding.
    """
    # Load dataset
    test_data = Dataset(data_type='test',
                        label_type_main=params['label_type_main'],
                        label_type_sub='phone39',
                        batch_size=batch_size,
                        splice=params['splice'],
                        num_stack=params['num_stack'],
                        num_skip=params['num_skip'],
                        shuffle=False,
                        progressbar=True)

    # Define placeholders
    model.create_placeholders()

    # Add to the graph each operation
    _, logits_main, logits_sub = model.compute_loss(
        model.inputs_pl_list[0], model.labels_pl_list[0],
        model.labels_sub_pl_list[0], model.inputs_seq_len_pl_list[0],
        model.keep_prob_input_pl_list[0], model.keep_prob_hidden_pl_list[0],
        model.keep_prob_output_pl_list[0])
    decode_op_main, decode_op_sub = model.decoder(
        logits_main,
        logits_sub,
        model.inputs_seq_len_pl_list[0],
        beam_width=beam_width)
    _, per_op = model.compute_ler(decode_op_main, decode_op_sub,
                                  model.labels_pl_list[0],
                                  model.labels_sub_pl_list[0])

    # Create a saver for writing training checkpoints
    saver = tf.train.Saver()

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(model.save_path)

        # If a checkpoint exists
        if ckpt:
            # Use last saved model
            model_path = ckpt.model_checkpoint_path
            if epoch != -1:
                # Use the best model
                # NOTE: In the training stage, parameters are saved only when
                # accuracies are improved
                model_path = model_path.split('/')[:-1]
                model_path = '/'.join(model_path) + '/model.ckpt-' + str(epoch)
            saver.restore(sess, model_path)
            print("Model restored: " + model_path)
        else:
            raise ValueError('There are not any checkpoints.')

        print('=== Test Data Evaluation ===')
        cer_test, wer_test = do_eval_cer(session=sess,
                                         decode_op=decode_op_main,
                                         model=model,
                                         dataset=test_data,
                                         label_type=params['label_type_main'],
                                         eval_batch_size=1,
                                         progressbar=True,
                                         is_multitask=True)
        print('  WER: %f %%' % (wer_test * 100))
        print('  CER: %f %%' % (cer_test * 100))

        per_test = do_eval_per(session=sess,
                               decode_op=decode_op_sub,
                               per_op=per_op,
                               model=model,
                               dataset=test_data,
                               label_type=params['label_type_sub'],
                               eval_batch_size=1,
                               progressbar=True,
                               is_multitask=True)
        print('  PER: %f %%' % (per_test * 100))
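
do_eval_cer presumably averages the character-level edit distance between hypotheses and references over the test set. A self-contained sketch of that per-utterance metric:

def character_error_rate_sketch(hypothesis, reference):
    """Levenshtein distance between two strings, normalized by the reference length."""
    # Standard dynamic-programming edit distance, computed row by row
    prev_row = list(range(len(reference) + 1))
    for i, hyp_char in enumerate(hypothesis, start=1):
        curr_row = [i]
        for j, ref_char in enumerate(reference, start=1):
            substitution = prev_row[j - 1] + (hyp_char != ref_char)
            insertion = curr_row[j - 1] + 1
            deletion = prev_row[j] + 1
            curr_row.append(min(substitution, insertion, deletion))
        prev_row = curr_row
    return prev_row[-1] / max(len(reference), 1)

For example, character_error_rate_sketch('sil sh iy', 'sil sh ix') gives 1 / 9.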
Example #9
def do_plot(network, param, epoch=None):
    """Plot the multi-task CTC posteriors.
    Args:
        network: model to restore
        param: A dictionary of parameters
        epoch: int, the epoch to restore
    """
    # Load dataset
    test_data = Dataset(data_type='test',
                        label_type_main='character',
                        label_type_sub=param['label_type_sub'],
                        batch_size=1,
                        num_stack=param['num_stack'],
                        num_skip=param['num_skip'],
                        is_sorted=False,
                        is_progressbar=True)

    # Define placeholders
    network.inputs = tf.placeholder(tf.float32,
                                    shape=[None, None, network.input_size],
                                    name='input')
    indices_pl = tf.placeholder(tf.int64, name='indices')
    values_pl = tf.placeholder(tf.int32, name='values')
    shape_pl = tf.placeholder(tf.int64, name='shape')
    network.labels = tf.SparseTensor(indices_pl, values_pl, shape_pl)
    indices_sub_pl = tf.placeholder(tf.int64, name='indices_sub')
    values_sub_pl = tf.placeholder(tf.int32, name='values_sub')
    shape_sub_pl = tf.placeholder(tf.int64, name='shape_sub')
    network.labels_sub = tf.SparseTensor(indices_sub_pl, values_sub_pl,
                                         shape_sub_pl)
    network.inputs_seq_len = tf.placeholder(tf.int64,
                                            shape=[None],
                                            name='inputs_seq_len')
    network.keep_prob_input = tf.placeholder(tf.float32,
                                             name='keep_prob_input')
    network.keep_prob_hidden = tf.placeholder(tf.float32,
                                              name='keep_prob_hidden')

    # Add to the graph each operation (including model definition)
    _, logits_main, logits_sub = network.compute_loss(network.inputs,
                                                      network.labels,
                                                      network.labels_sub,
                                                      network.inputs_seq_len,
                                                      network.keep_prob_input,
                                                      network.keep_prob_hidden)
    posteriors_op_main, posteriors_op_sub = network.posteriors(
        logits_main, logits_sub)

    # Create a saver for writing training checkpoints
    saver = tf.train.Saver()

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(network.model_dir)

        # If a checkpoint exists
        if ckpt:
            # Use last saved model
            model_path = ckpt.model_checkpoint_path
            if epoch is not None:
                model_path = model_path.split('/')[:-1]
                model_path = '/'.join(model_path) + '/model.ckpt-' + str(epoch)
            saver.restore(sess, model_path)
            print("Model restored: " + model_path)
        else:
            raise ValueError('There are not any checkpoints.')

        # Visualize
        posterior_test_multitask(session=sess,
                                 posteriors_op_main=posteriors_op_main,
                                 posteriors_op_sub=posteriors_op_sub,
                                 network=network,
                                 dataset=test_data,
                                 label_type_sub=param['label_type_sub'],
                                 save_path=network.model_dir,
                                 show=False)
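
posterior_test_multitask is assumed to write plots of the frame-level posteriors for both tasks. A minimal matplotlib sketch of such a plotting routine (array shapes and styling are assumptions):

import matplotlib
matplotlib.use('Agg')  # render to files without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_posteriors_sketch(posteriors, save_path):
    """Plot per-class posterior curves over frames for a single utterance.

    Args:
        posteriors: np.ndarray of shape [num_frames, num_classes]
        save_path: output image path
    """
    plt.figure(figsize=(10, 4))
    for class_idx in range(posteriors.shape[1]):
        plt.plot(np.arange(posteriors.shape[0]), posteriors[:, class_idx])
    plt.xlabel('Frame')
    plt.ylabel('Posterior probability')
    plt.ylim([0.0, 1.0])
    plt.savefig(save_path, dpi=150)
    plt.close()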