def do_eval(network, label_type_second, num_stack, num_skip, epoch=None):
    """Restore a trained multi-task CTC model and evaluate it on the test set.

    The main task is scored by CER and the second task by PER; both values are
    printed as percentages.

    Args:
        network: model to restore
        label_type_second: string, phone39 or phone48 or phone61. If it is
            'character', the test set is loaded with character labels for the
            second task; otherwise the test set always uses phone39 labels.
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
        epoch: int, the epoch to restore. If None, the latest checkpoint
            recorded in network.model_dir is restored.
    Raises:
        ValueError: if no checkpoint is found in network.model_dir.
    """
    import os

    # Load dataset. Any phone label type is evaluated against phone39 labels.
    if label_type_second == 'character':
        test_label_type = 'character'
    else:
        test_label_type = 'phone39'
    test_data = DataSet(data_type='test',
                        label_type_second=test_label_type,
                        batch_size=1,
                        num_stack=num_stack,
                        num_skip=num_skip,
                        is_sorted=False,
                        is_progressbar=True)

    # Define placeholders
    network.inputs = tf.placeholder(tf.float32,
                                    shape=[None, None, network.input_size],
                                    name='input')
    indices_pl = tf.placeholder(tf.int64, name='indices')
    values_pl = tf.placeholder(tf.int32, name='values')
    shape_pl = tf.placeholder(tf.int64, name='shape')
    network.labels = tf.SparseTensor(indices_pl, values_pl, shape_pl)
    indices_second_pl = tf.placeholder(tf.int64, name='indices_second')
    values_second_pl = tf.placeholder(tf.int32, name='values_second')
    shape_second_pl = tf.placeholder(tf.int64, name='shape_second')
    network.labels_second = tf.SparseTensor(indices_second_pl,
                                            values_second_pl, shape_second_pl)
    network.inputs_seq_len = tf.placeholder(tf.int64,
                                            shape=[None],
                                            name='inputs_seq_len')
    # Dropout keep probabilities, declared the same way as in do_plot and
    # do_train so compute_loss receives the arguments it is given there.
    # NOTE(review): presumably do_eval_cer / do_eval_per feed these with 1.0
    # — confirm.
    network.keep_prob_input = tf.placeholder(tf.float32,
                                             name='keep_prob_input')
    network.keep_prob_hidden = tf.placeholder(tf.float32,
                                              name='keep_prob_hidden')

    # Add to the graph each operation (including model definition)
    _, logits_main, logits_second = network.compute_loss(
        network.inputs, network.labels, network.labels_second,
        network.inputs_seq_len, network.keep_prob_input,
        network.keep_prob_hidden)
    decode_op_main, decode_op_second = network.decoder(
        logits_main,
        logits_second,
        network.inputs_seq_len,
        decode_type='beam_search',
        beam_width=20)
    per_op_main, per_op_second = network.compute_ler(decode_op_main,
                                                     decode_op_second,
                                                     network.labels,
                                                     network.labels_second)

    # Create a saver for restoring checkpoints
    saver = tf.train.Saver()

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(network.model_dir)

        # If check point exists
        if ckpt:
            # Use last saved model
            model_path = ckpt.model_checkpoint_path
            if epoch is not None:
                # Keep the latest checkpoint's directory and swap in the
                # requested epoch. dirname() returns '' for a bare file name,
                # so this never produces a path rooted at '/'.
                model_path = os.path.join(os.path.dirname(model_path),
                                          'model.ckpt-' + str(epoch))
            saver.restore(sess, model_path)
            print("Model restored: " + model_path)
        else:
            raise ValueError('There are not any checkpoints.')

        print('=== Test Data Evaluation ===')
        cer_test = do_eval_cer(session=sess,
                               decode_op=decode_op_main,
                               network=network,
                               dataset=test_data,
                               is_progressbar=True,
                               is_multitask=True)
        print('  CER: %f %%' % (cer_test * 100))

        per_test = do_eval_per(session=sess,
                               decode_op=decode_op_second,
                               per_op=per_op_second,
                               network=network,
                               dataset=test_data,
                               train_label_type=label_type_second,
                               is_progressbar=True,
                               is_multitask=True)
        print('  PER: %f %%' % (per_test * 100))
def do_plot(network, label_type_second, num_stack, num_skip, epoch=None):
    """Restore a trained multi-task CTC model and visualize its posteriors.

    The posteriors of both tasks on the test set are plotted by
    posterior_test_multitask, which writes under network.model_dir.

    Args:
        network: model to restore
        label_type_second: string, phone39 or phone48 or phone61
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
        epoch: int, the epoch to restore (latest checkpoint when None)
    Raises:
        ValueError: if network.model_dir contains no checkpoint.
    """
    test_data = DataSet(data_type='test',
                        label_type_second=label_type_second,
                        batch_size=1,
                        num_stack=num_stack,
                        num_skip=num_skip,
                        is_sorted=False,
                        is_progressbar=True)

    def _sparse_label_placeholder(indices_name, values_name, shape_name):
        # A SparseTensor label fed through three placeholders
        # (indices, values, dense shape), created in that order.
        return tf.SparseTensor(tf.placeholder(tf.int64, name=indices_name),
                               tf.placeholder(tf.int32, name=values_name),
                               tf.placeholder(tf.int64, name=shape_name))

    # Graph inputs
    network.inputs = tf.placeholder(
        tf.float32, shape=[None, None, network.input_size], name='input')
    network.labels = _sparse_label_placeholder(
        'indices', 'values', 'shape')
    network.labels_second = _sparse_label_placeholder(
        'indices_second', 'values_second', 'shape_second')
    network.inputs_seq_len = tf.placeholder(
        tf.int64, shape=[None], name='inputs_seq_len')
    network.keep_prob_input = tf.placeholder(
        tf.float32, name='keep_prob_input')
    network.keep_prob_hidden = tf.placeholder(
        tf.float32, name='keep_prob_hidden')

    # Model definition plus the posterior ops for both tasks
    _, main_logits, second_logits = network.compute_loss(
        network.inputs, network.labels, network.labels_second,
        network.inputs_seq_len, network.keep_prob_input,
        network.keep_prob_hidden)
    main_posteriors, second_posteriors = network.posteriors(
        main_logits, second_logits)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(network.model_dir)
        if not ckpt:
            raise ValueError('There are not any checkpoints.')

        model_path = ckpt.model_checkpoint_path
        if epoch is not None:
            # Same directory as the latest checkpoint, requested epoch
            checkpoint_dir = '/'.join(model_path.split('/')[:-1])
            model_path = checkpoint_dir + '/model.ckpt-' + str(epoch)
        saver.restore(sess, model_path)
        print("Model restored: " + model_path)

        posterior_test_multitask(session=sess,
                                 posteriors_op_main=main_posteriors,
                                 posteriors_op_second=second_posteriors,
                                 network=network,
                                 dataset=test_data,
                                 label_type_second=label_type_second,
                                 save_path=network.model_dir)
# Example No. 3
# 0
def do_train(network, optimizer, learning_rate, batch_size, epoch_num, label_type, num_stack, num_skip):
    """Run multi-task (character + phone) CTC training.

    Builds the model in a fresh tf.Graph via network.define()/loss()/train()/
    decoder()/ler() and trains for epoch_num epochs. The train and dev losses
    are recorded every step. Every 10 steps the error rates on the current
    train/dev mini-batches are printed and written as summaries. A checkpoint
    is saved at the end of every epoch. From epoch 10 onward the model is
    evaluated on the dev set (phone39), and on the test set whenever the dev
    CER improves. At the end the losses are saved with save_loss and an empty
    complete.txt is written to network.model_dir.

    Args:
        network: network to train
        optimizer: string, the name of optimizer. ex.) adam, rmsprop
        learning_rate: initial learning rate (also fed to network.lr_pl)
        batch_size: size of mini batch
        epoch_num: epoch num to train
        label_type: phone39 or phone48 or phone61 (+ character); label type
            of the training set, also forwarded to do_eval_per
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
    """
    # Load dataset
    train_data = DataSet(data_type='train', label_type=label_type,
                         num_stack=num_stack, num_skip=num_skip,
                         is_sorted=True)
    # dev_data61 supplies per-step dev mini-batches; dev_data39 is used for
    # the per-epoch evaluation.
    dev_data61 = DataSet(data_type='dev', label_type='phone61',
                         num_stack=num_stack, num_skip=num_skip,
                         is_sorted=False)
    dev_data39 = DataSet(data_type='dev', label_type='phone39',
                         num_stack=num_stack, num_skip=num_skip,
                         is_sorted=False)
    test_data = DataSet(data_type='test', label_type='phone39',
                        num_stack=num_stack, num_skip=num_skip,
                        is_sorted=False)

    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default():

        # Define model
        network.define()
        # NOTE: define model under tf.Graph()

        # Add to the graph each operation
        loss_op = network.loss()
        train_op = network.train(optimizer=optimizer,
                                 learning_rate_init=learning_rate,
                                 is_scheduled=False)
        decode_op1, decode_op2 = network.decoder(decode_type='beam_search',
                                                 beam_width=20)
        per_op1, per_op2 = network.ler(decode_op1, decode_op2)

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(network.summaries_train)
        summary_dev = tf.summary.merge(network.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        # (max_to_keep=None keeps every epoch's checkpoint)
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()), "{:,}".format(total_parameters / 1000000)))

        # Per-step loss history, written out by save_loss after training
        csv_steps = []
        csv_train_loss = []
        csv_dev_loss = []
        # Create a session for running operation on the graph
        with tf.Session() as sess:

            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(
                network.model_dir, sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Train model
            # Ceiling division: a final partial mini-batch counts as one
            # iteration of the epoch.
            iter_per_epoch = int(train_data.data_num / batch_size)
            if (train_data.data_num / batch_size) != int(train_data.data_num / batch_size):
                iter_per_epoch += 1
            max_steps = iter_per_epoch * epoch_num
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            # Best dev CER seen so far; starts at 1 so the first evaluated
            # epoch counts as an improvement (assumes CER <= 1 — confirm).
            error_best = 1
            for step in range(max_steps):

                # Create feed dictionary for next mini batch (train)
                inputs, labels_char, labels_phone, seq_len, _ = train_data.next_batch(
                    batch_size=batch_size)
                indices_char, values_char, dense_shape_char = list2sparsetensor(
                    labels_char)
                indices_phone, values_phone, dense_shape_phone = list2sparsetensor(
                    labels_phone)
                feed_dict_train = {
                    network.inputs_pl: inputs,
                    network.label_indices_pl: indices_char,
                    network.label_values_pl: values_char,
                    network.label_shape_pl: dense_shape_char,
                    network.label_indices_pl2: indices_phone,
                    network.label_values_pl2: values_phone,
                    network.label_shape_pl2: dense_shape_phone,
                    network.seq_len_pl: seq_len,
                    network.keep_prob_input_pl: network.dropout_ratio_input,
                    network.keep_prob_hidden_pl: network.dropout_ratio_hidden,
                    network.lr_pl: learning_rate
                }

                # Create feed dictionary for next mini batch (dev)
                # NOTE(review): the dev feed uses the training dropout ratios,
                # so loss_dev below is measured with dropout active — confirm
                # this is intended.
                inputs, labels_char, labels_phone, seq_len, _ = dev_data61.next_batch(
                    batch_size=batch_size)
                indices_char, values_char, dense_shape_char = list2sparsetensor(
                    labels_char)
                indices_phone, values_phone, dense_shape_phone = list2sparsetensor(
                    labels_phone)
                feed_dict_dev = {
                    network.inputs_pl: inputs,
                    network.label_indices_pl: indices_char,
                    network.label_values_pl: values_char,
                    network.label_shape_pl: dense_shape_char,
                    network.label_indices_pl2: indices_phone,
                    network.label_values_pl2: values_phone,
                    network.label_shape_pl2: dense_shape_phone,
                    network.seq_len_pl: seq_len,
                    network.keep_prob_input_pl: network.dropout_ratio_input,
                    network.keep_prob_hidden_pl: network.dropout_ratio_hidden
                }

                # Update parameters & compute loss
                _, loss_train = sess.run(
                    [train_op, loss_op], feed_dict=feed_dict_train)
                loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                csv_steps.append(step)
                csv_train_loss.append(loss_train)
                csv_dev_loss.append(loss_dev)

                if (step + 1) % 10 == 0:

                    # Change feed dict for evaluation
                    # (keep_prob = 1.0 disables dropout; the parameter update
                    # for this step has already been run above)
                    feed_dict_train[network.keep_prob_input_pl] = 1.0
                    feed_dict_train[network.keep_prob_hidden_pl] = 1.0
                    feed_dict_dev[network.keep_prob_input_pl] = 1.0
                    feed_dict_dev[network.keep_prob_hidden_pl] = 1.0

                    # Compute accuracy & update event file
                    # per_op1 scores the character task (CER), per_op2 the
                    # phone task (PER).
                    cer_train, per_train, summary_str_train = sess.run([per_op1, per_op2, summary_train],
                                                                       feed_dict=feed_dict_train)
                    cer_dev, per_dev, summary_str_dev = sess.run([per_op1, per_op2,  summary_dev],
                                                                 feed_dict=feed_dict_dev)
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    duration_step = time.time() - start_time_step
                    print('Step %d: loss = %.3f (%.3f) / cer = %.4f (%.4f) / per = %.4f (%.4f) (%.3f min)' %
                          (step + 1, loss_train, loss_dev, cer_train, cer_dev, per_train, per_dev, duration_step / 60))
                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if (step + 1) % iter_per_epoch == 0 or (step + 1) == max_steps:
                    duration_epoch = time.time() - start_time_epoch
                    epoch = (step + 1) // iter_per_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (epoch, duration_epoch / 60))

                    # Save model (check point)
                    checkpoint_file = join(network.model_dir, 'model.ckpt')
                    save_path = saver.save(
                        sess, checkpoint_file, global_step=epoch)
                    print("Model saved in file: %s" % save_path)

                    # Full-set evaluation only from epoch 10 onward
                    if epoch >= 10:
                        start_time_eval = time.time()

                        print('■Dev Data Evaluation:■')
                        error_epoch = do_eval_cer(session=sess,
                                                  decode_op=decode_op1,
                                                  network=network,
                                                  dataset=dev_data39,
                                                  eval_batch_size=1,
                                                  is_multitask=True)
                        do_eval_per(session=sess,
                                    decode_op=decode_op2,
                                    per_op=per_op2,
                                    network=network,
                                    dataset=dev_data39,
                                    label_type=label_type,
                                    eval_batch_size=1,
                                    is_multitask=True)

                        # The test set is evaluated only when dev CER improves
                        if error_epoch < error_best:
                            error_best = error_epoch
                            print('■■■ ↑Best Score (CER)↑ ■■■')

                            print('■Test Data Evaluation:■')
                            do_eval_cer(session=sess,
                                        decode_op=decode_op1,
                                        network=network,
                                        dataset=test_data,
                                        eval_batch_size=1,
                                        is_multitask=True)

                            do_eval_per(session=sess,
                                        decode_op=decode_op2,
                                        per_op=per_op2,
                                        network=network,
                                        dataset=test_data,
                                        label_type=label_type,
                                        eval_batch_size=1,
                                        is_multitask=True)

                        duration_eval = time.time() - start_time_eval
                        print('Evaluation time: %.3f min' %
                              (duration_eval / 60))

                    # Restart the timers for the next epoch
                    start_time_epoch = time.time()
                    start_time_step = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Save train & dev loss
            save_loss(csv_steps, csv_train_loss, csv_dev_loss,
                      save_path=network.model_dir)

            # Training was finished correctly
            # (an empty marker file signals successful completion)
            with open(join(network.model_dir, 'complete.txt'), 'w') as f:
                f.write('')
def do_train(network, optimizer, learning_rate, batch_size, epoch_num,
             label_type_second, num_stack, num_skip):
    """Run multi-task CTC training.

    The main task is trained on character labels and the second task on phone
    labels of type ``label_type_second``. Every 10 steps, the loss and error
    rates on the current train/dev mini-batches are logged: CER for the main
    task and PER for the second task. A checkpoint is saved at the end of every
    epoch. From epoch 10 onward the model is evaluated on the full dev set, and
    on the test set (phone39 labels) whenever the dev CER improves.

    Args:
        network: network to train
        optimizer: string, the name of optimizer.
            ex.) adam, rmsprop
        learning_rate: A float value, the initial learning rate
        batch_size: int, the size of mini-batch
        epoch_num: int, the epoch num to train
        label_type_second: string, phone39 or phone48 or phone61
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
    """
    # Load dataset
    train_data = DataSet(data_type='train',
                         label_type_second=label_type_second,
                         batch_size=batch_size,
                         num_stack=num_stack, num_skip=num_skip,
                         is_sorted=True)
    dev_data = DataSet(data_type='dev', label_type_second=label_type_second,
                       batch_size=batch_size,
                       num_stack=num_stack, num_skip=num_skip,
                       is_sorted=False)
    test_data = DataSet(data_type='test', label_type_second='phone39',
                        batch_size=1,
                        num_stack=num_stack, num_skip=num_skip,
                        is_sorted=False)

    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default():

        # Define placeholders
        network.inputs = tf.placeholder(
            tf.float32,
            shape=[None, None, network.input_size],
            name='input')
        indices_pl = tf.placeholder(tf.int64, name='indices')
        values_pl = tf.placeholder(tf.int32, name='values')
        shape_pl = tf.placeholder(tf.int64, name='shape')
        network.labels = tf.SparseTensor(indices_pl, values_pl, shape_pl)
        indices_second_pl = tf.placeholder(tf.int64, name='indices_second')
        values_second_pl = tf.placeholder(tf.int32, name='values_second')
        shape_second_pl = tf.placeholder(tf.int64, name='shape_second')
        network.labels_second = tf.SparseTensor(indices_second_pl,
                                                values_second_pl,
                                                shape_second_pl)
        network.inputs_seq_len = tf.placeholder(tf.int64,
                                                shape=[None],
                                                name='inputs_seq_len')
        network.keep_prob_input = tf.placeholder(tf.float32,
                                                 name='keep_prob_input')
        network.keep_prob_hidden = tf.placeholder(tf.float32,
                                                  name='keep_prob_hidden')

        # Add to the graph each operation
        loss_op, logits_main, logits_second = network.compute_loss(
            network.inputs,
            network.labels,
            network.labels_second,
            network.inputs_seq_len,
            network.keep_prob_input,
            network.keep_prob_hidden)
        train_op = network.train(loss_op,
                                 optimizer=optimizer,
                                 learning_rate_init=float(learning_rate),
                                 is_scheduled=False)
        decode_op_main, decode_op_second = network.decoder(
            logits_main,
            logits_second,
            network.inputs_seq_len,
            decode_type='beam_search',
            beam_width=20)
        ler_op_main, ler_op_second = network.compute_ler(
            decode_op_main, decode_op_second,
            network.labels, network.labels_second)

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(network.summaries_train)
        summary_dev = tf.summary.merge(network.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        # (max_to_keep=None keeps every epoch's checkpoint)
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()),
               "{:,}".format(total_parameters / 1000000)))

        csv_steps, csv_loss_train, csv_loss_dev = [], [], []
        csv_cer_train, csv_cer_dev = [], []
        csv_per_train, csv_per_dev = [], []
        # Create a session for running operation on the graph
        with tf.Session() as sess:

            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(
                network.model_dir, sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Make mini-batch generators
            mini_batch_train = train_data.next_batch()
            mini_batch_dev = dev_data.next_batch()

            # Iterations per epoch: a final partial mini-batch counts as one
            iter_per_epoch, remainder = divmod(train_data.data_num, batch_size)
            if remainder:
                iter_per_epoch += 1
            max_steps = iter_per_epoch * epoch_num
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            cer_dev_best = 1
            for step in range(max_steps):

                # Create feed dictionary for next mini batch (train). The
                # label values are fed directly to the SparseTensor keys.
                (inputs_train, labels_char_train, labels_phone_train,
                 inputs_seq_len_train, _) = next(mini_batch_train)
                feed_dict_train = {
                    network.inputs: inputs_train,
                    network.labels: labels_char_train,
                    network.labels_second: labels_phone_train,
                    network.inputs_seq_len: inputs_seq_len_train,
                    network.keep_prob_input: network.dropout_ratio_input,
                    network.keep_prob_hidden: network.dropout_ratio_hidden,
                    network.lr: learning_rate
                }

                # Create feed dictionary for next mini batch (dev). The labels
                # must come from the dev batch so that dev loss and LER are
                # measured against the dev inputs' own targets.
                # NOTE(review): dropout ratios here are the training ones, so
                # loss_dev is measured with dropout active — confirm intended.
                (inputs_dev, labels_char_dev, labels_phone_dev,
                 inputs_seq_len_dev, _) = next(mini_batch_dev)
                feed_dict_dev = {
                    network.inputs: inputs_dev,
                    network.labels: labels_char_dev,
                    network.labels_second: labels_phone_dev,
                    network.inputs_seq_len: inputs_seq_len_dev,
                    network.keep_prob_input: network.dropout_ratio_input,
                    network.keep_prob_hidden: network.dropout_ratio_hidden
                }

                # Update parameters
                sess.run(train_op, feed_dict=feed_dict_train)

                if (step + 1) % 10 == 0:
                    # Compute loss
                    loss_train = sess.run(loss_op, feed_dict=feed_dict_train)
                    loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                    csv_steps.append(step)
                    csv_loss_train.append(loss_train)
                    csv_loss_dev.append(loss_dev)

                    # Change to evaluation mode (keep_prob = 1.0 disables
                    # dropout)
                    feed_dict_train[network.keep_prob_input] = 1.0
                    feed_dict_train[network.keep_prob_hidden] = 1.0
                    feed_dict_dev[network.keep_prob_input] = 1.0
                    feed_dict_dev[network.keep_prob_hidden] = 1.0

                    # Compute accuracy & update event file
                    cer_train, per_train, summary_str_train = sess.run(
                        [ler_op_main, ler_op_second, summary_train],
                        feed_dict=feed_dict_train)
                    cer_dev, per_dev, summary_str_dev = sess.run(
                        [ler_op_main, ler_op_second, summary_dev],
                        feed_dict=feed_dict_dev)
                    csv_cer_train.append(cer_train)
                    csv_cer_dev.append(cer_dev)
                    csv_per_train.append(per_train)
                    csv_per_dev.append(per_dev)
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    duration_step = time.time() - start_time_step
                    print("Step %d: loss = %.3f (%.3f) / cer = %.4f (%.4f) / per = %.4f (%.4f) (%.3f min)" %
                          (step + 1, loss_train, loss_dev, cer_train, cer_dev,
                           per_train, per_dev, duration_step / 60))
                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if (step + 1) % iter_per_epoch == 0 or (step + 1) == max_steps:
                    duration_epoch = time.time() - start_time_epoch
                    epoch = (step + 1) // iter_per_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (epoch, duration_epoch / 60))

                    # Save model (check point)
                    checkpoint_file = join(network.model_dir, 'model.ckpt')
                    save_path = saver.save(
                        sess, checkpoint_file, global_step=epoch)
                    print("Model saved in file: %s" % save_path)

                    # Full-set evaluation only from epoch 10 onward
                    if epoch >= 10:
                        start_time_eval = time.time()
                        print('=== Dev Data Evaluation ===')
                        cer_dev_epoch = do_eval_cer(
                            session=sess,
                            decode_op=decode_op_main,
                            network=network,
                            dataset=dev_data,
                            is_multitask=True)
                        print('  CER: %f %%' % (cer_dev_epoch * 100))
                        per_dev_epoch = do_eval_per(
                            session=sess,
                            decode_op=decode_op_second,
                            per_op=ler_op_second,
                            network=network,
                            dataset=dev_data,
                            train_label_type=label_type_second,
                            is_multitask=True)
                        print('  PER: %f %%' % (per_dev_epoch * 100))

                        # The test set is evaluated only when dev CER improves
                        if cer_dev_epoch < cer_dev_best:
                            cer_dev_best = cer_dev_epoch
                            print('■■■ ↑Best Score (CER)↑ ■■■')

                            print('=== Test Data Evaluation ===')
                            cer_test = do_eval_cer(
                                session=sess,
                                decode_op=decode_op_main,
                                network=network,
                                dataset=test_data,
                                eval_batch_size=1,
                                is_multitask=True)
                            print('  CER: %f %%' % (cer_test * 100))
                            per_test = do_eval_per(
                                session=sess,
                                decode_op=decode_op_second,
                                per_op=ler_op_second,
                                network=network,
                                dataset=test_data,
                                train_label_type=label_type_second,
                                eval_batch_size=1,
                                is_multitask=True)
                            print('  PER: %f %%' % (per_test * 100))

                        duration_eval = time.time() - start_time_eval
                        print('Evaluation time: %.3f min' %
                              (duration_eval / 60))

                    # Restart the timers every epoch, whether or not the
                    # evaluation above ran; otherwise epochs 1-9 report
                    # cumulative durations.
                    start_time_epoch = time.time()
                    start_time_step = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Save train & dev loss, ler
            save_loss(csv_steps, csv_loss_train, csv_loss_dev,
                      save_path=network.model_dir)
            # NOTE(review): both save_ler calls share one save_path; if
            # save_ler writes to a fixed file name, the PER curve overwrites
            # the CER one — check save_ler.
            save_ler(csv_steps, csv_cer_train, csv_cer_dev,
                     save_path=network.model_dir)
            save_ler(csv_steps, csv_per_train, csv_per_dev,
                     save_path=network.model_dir)

            # Training was finished correctly (empty marker file)
            with open(join(network.model_dir, 'complete.txt'), 'w') as f:
                f.write('')
# Example No. 5
# 0
def do_restore(network, label_type, num_stack, num_skip, epoch=None):
    """Restore a saved multi-task model and evaluate it on the test set.

    Args:
        network: model to restore
        label_type: phone39 or phone48 or phone61; 'character' loads the test
            set with character labels, any other value with phone39 labels
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
        epoch: epoch to restore; the latest checkpoint is used when None
    Raises:
        ValueError: if network.model_dir holds no checkpoint.
    """
    test_label_type = 'character' if label_type == 'character' else 'phone39'
    test_data = DataSet(data_type='test',
                        label_type=test_label_type,
                        num_stack=num_stack,
                        num_skip=num_skip,
                        is_sorted=False,
                        is_progressbar=True)

    # Build the model, then the beam-search decoders and error-rate ops
    network.define()
    decode_op1, decode_op2 = network.decoder(decode_type='beam_search',
                                             beam_width=20)
    per_op1, per_op2 = network.ler(decode_op1, decode_op2)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(network.model_dir)
        if not ckpt:
            raise ValueError('There are not any checkpoints.')

        model_path = ckpt.model_checkpoint_path
        if epoch is not None:
            # Keep the checkpoint directory, swap in the requested epoch
            checkpoint_dir = '/'.join(model_path.split('/')[:-1])
            model_path = checkpoint_dir + '/model.ckpt-' + str(epoch)
        saver.restore(sess, model_path)
        print("Model restored: " + model_path)

        print('Test Data Evaluation:')
        # Arguments shared by both evaluation helpers
        shared_kwargs = dict(session=sess,
                             network=network,
                             dataset=test_data,
                             is_progressbar=True,
                             is_multitask=True)
        do_eval_cer(decode_op=decode_op1, **shared_kwargs)
        do_eval_per(decode_op=decode_op2,
                    per_op=per_op2,
                    label_type=label_type,
                    **shared_kwargs)