def do_train(model, params):
    """Run training.
    Args:
        model: the model to train
        params (dict): A dictionary of parameters
    """
    # Load dataset
    train_data = Dataset(
        data_type='train', label_type=params['label_type'],
        batch_size=params['batch_size'], max_epoch=params['num_epoch'],
        splice=params['splice'],
        num_stack=params['num_stack'], num_skip=params['num_skip'],
        shuffle=True)
    dev_data = Dataset(
        data_type='dev', label_type=params['label_type'],
        batch_size=params['batch_size'], splice=params['splice'],
        num_stack=params['num_stack'], num_skip=params['num_skip'],
        shuffle=False)
    test_data = Dataset(
        data_type='test', label_type=params['label_type'],
        batch_size=params['batch_size'], splice=params['splice'],
        num_stack=params['num_stack'], num_skip=params['num_skip'],
        shuffle=False)

    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default():

        # Define placeholders
        model.create_placeholders()
        learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

        # Add to the graph each operation (including model definition)
        loss_op, logits = model.compute_loss(
            model.inputs_pl_list[0],
            model.labels_pl_list[0],
            model.inputs_seq_len_pl_list[0],
            model.keep_prob_pl_list[0])
        train_op = model.train(
            loss_op,
            optimizer=params['optimizer'],
            learning_rate=learning_rate_pl)
        decode_op = model.decoder(logits,
                                  model.inputs_seq_len_pl_list[0],
                                  beam_width=params['beam_width'])
        ler_op = model.compute_ler(decode_op, model.labels_pl_list[0])
        posteriors_op = model.posteriors(logits, blank_prior=1)
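        # NOTE: posteriors_op is consumed only by the commented-out
        # time-F-measure evaluation further below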

        # Define learning rate controller
        lr_controller = Controller(
            learning_rate_init=params['learning_rate'],
            decay_start_epoch=params['decay_start_epoch'],
            decay_rate=params['decay_rate'],
            decay_patient_epoch=params['decay_patient_epoch'],
            lower_better=False)

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(model.summaries_train)
        summary_dev = tf.summary.merge(model.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()),
               "{:,}".format(total_parameters / 1000000)))

        csv_steps, csv_loss_train, csv_loss_dev = [], [], []
        csv_ler_train, csv_ler_dev = [], []
        # Create a session for running operations on the graph
        with tf.Session() as sess:

            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(
                model.save_path, sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Train model
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            fmean_dev_best = 0
            fmean_time_dev_best = 0
            learning_rate = float(params['learning_rate'])
            for step, (data, is_new_epoch) in enumerate(train_data):

                # Create feed dictionary for next mini batch (train)
                inputs, labels, inputs_seq_len, _ = data
                feed_dict_train = {
                    model.inputs_pl_list[0]: inputs[0],
                    model.labels_pl_list[0]: list2sparsetensor(
                        labels[0], padded_value=train_data.padded_value),
                    model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
                    model.keep_prob_pl_list[0]: 1 - float(params['dropout']),
                    learning_rate_pl: learning_rate
                }

                # Update parameters
                sess.run(train_op, feed_dict=feed_dict_train)

                if (step + 1) % params['print_step'] == 0:

                    # Create feed dictionary for next mini batch (dev)
                    (inputs, labels, inputs_seq_len, _), _ = dev_data.next()
                    feed_dict_dev = {
                        model.inputs_pl_list[0]: inputs[0],
                        model.labels_pl_list[0]: list2sparsetensor(
                            labels[0], padded_value=dev_data.padded_value),
                        model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
                        model.keep_prob_pl_list[0]: 1.0
                    }

                    # Compute loss
                    loss_train = sess.run(loss_op, feed_dict=feed_dict_train)
                    loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                    csv_steps.append(step)
                    csv_loss_train.append(loss_train)
                    csv_loss_dev.append(loss_dev)

                    # Change to evaluation mode
                    feed_dict_train[model.keep_prob_pl_list[0]] = 1.0

                    # Compute accuracy & update event files
                    ler_train, summary_str_train = sess.run(
                        [ler_op, summary_train], feed_dict=feed_dict_train)
                    ler_dev, summary_str_dev = sess.run(
                        [ler_op, summary_dev], feed_dict=feed_dict_dev)
                    csv_ler_train.append(ler_train)
                    csv_ler_dev.append(ler_dev)
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    duration_step = time.time() - start_time_step
                    print("Step %d (epoch: %.3f): loss = %.3f (%.3f) / ler = %.3f (%.3f) / lr = %.5f (%.3f min)" %
                          (step + 1, train_data.epoch_detail, loss_train, loss_dev, ler_train, ler_dev,
                           learning_rate, duration_step / 60))
                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if is_new_epoch:
                    duration_epoch = time.time() - start_time_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (train_data.epoch, duration_epoch / 60))

                    # Save figures of loss & ler
                    plot_loss(csv_loss_train, csv_loss_dev, csv_steps,
                              save_path=model.save_path)
                    plot_ler(csv_ler_train, csv_ler_dev, csv_steps,
                             label_type=params['label_type'],
                             save_path=model.save_path)

                    if train_data.epoch >= params['eval_start_epoch']:
                        start_time_eval = time.time()
                        print('=== Dev Data Evaluation ===')
                        fmean_dev_epoch, df_acc = do_eval_fmeasure(
                            session=sess,
                            decode_op=decode_op,
                            model=model,
                            dataset=dev_data,
                            eval_batch_size=params['batch_size'])
                        print(df_acc)
                        print('  F-measure: %f %%' % (fmean_dev_epoch))

                        if fmean_dev_epoch > fmean_dev_best:
                            fmean_dev_best = fmean_dev_epoch
                            print('■■■ ↑Best Score (F-measure)↑ ■■■')

                            # Save model only when best accuracy is
                            # obtained (check point)
                            checkpoint_file = join(
                                model.save_path, 'model.ckpt')
                            save_path = saver.save(
                                sess, checkpoint_file, global_step=train_data.epoch)
                            print("Model saved in file: %s" % save_path)

                            print('=== Test Data Evaluation ===')
                            fmean_test_epoch, df_acc = do_eval_fmeasure(
                                session=sess,
                                decode_op=decode_op,
                                model=model,
                                dataset=test_data,
                                eval_batch_size=params['batch_size'])
                            print(df_acc)
                            print('  F-measure: %f %%' % (fmean_test_epoch))

                        # fmean_time_dev_epoch, df_acc = do_eval_fmeasure_time(
                        #     session=sess,
                        #     decode_op=decode_op,
                        #     posteriors_op=posteriors_op,
                        #     model=model,
                        #     dataset=dev_data,
                        #     eval_batch_size=params['batch_size'])
                        # print(df_acc)
                        # print('  Time F-measure: %f %%' %
                        #       (fmean_time_dev_epoch))

                        # if fmean_time_dev_best < fmean_time_dev_epoch:
                        #     fmean_time_dev_best = fmean_time_dev_epoch
                        #     print('■■■ ↑Best Score (Time F-measure)↑ ■■■')

                        # fmean_time_test_epoch, df_acc = do_eval_fmeasure_time(
                        #     session=sess,
                        #     decode_op=decode_op,
                        #     posteriors_op=posteriors_op,
                        #     model=model,
                        #     dataset=test_data,
                        #     eval_batch_size=params['batch_size'])
                        # print(df_acc)
                        # print('  Time F-measure: %f %%' %
                        #       (fmean_time_test_epoch))

                        duration_eval = time.time() - start_time_eval
                        print('Evaluation time: %.3f min' %
                              (duration_eval / 60))

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=train_data.epoch,
                            value=fmean_dev_epoch)

                    start_time_epoch = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Training was finished correctly
            with open(join(model.save_path, 'complete.txt'), 'w') as f:
                f.write('')
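
# The learning-rate Controller used above is assumed to implement
# patience-based step decay. The sketch below is a minimal, hypothetical
# reconstruction of that behavior (the repo's Controller may differ); the
# name _ControllerSketch is ours, not the library's.
class _ControllerSketch(object):
    """Decay the learning rate when the monitored value stops improving."""

    def __init__(self, learning_rate_init, decay_start_epoch, decay_rate,
                 decay_patient_epoch=1, lower_better=True):
        self.learning_rate_init = learning_rate_init
        self.decay_start_epoch = decay_start_epoch
        self.decay_rate = decay_rate
        self.decay_patient_epoch = decay_patient_epoch
        self.lower_better = lower_better
        self.best_value = float('inf')
        self.not_improved_epoch = 0

    def decay_lr(self, learning_rate, epoch, value):
        if not self.lower_better:
            value = -value  # flip sign so smaller is always better
        if epoch < self.decay_start_epoch:
            self.best_value = min(self.best_value, value)
            return learning_rate
        if value < self.best_value:
            # Improved: remember the score and keep the current rate
            self.best_value = value
            self.not_improved_epoch = 0
            return learning_rate
        self.not_improved_epoch += 1
        if self.not_improved_epoch < self.decay_patient_epoch:
            return learning_rate
        # Patience exhausted: decay and reset the counter
        self.not_improved_epoch = 0
        return learning_rate * self.decay_rate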
    def check(self, encoder_type, lstm_impl=None, time_major=False):

        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  lstm_impl: %s' % lstm_impl)
        print('  time_major: %s' % time_major)
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 4
            splice = 5 if encoder_type in ['vgg_blstm', 'vgg_lstm',
                                           'vgg_wang', 'resnet_wang', 'cldnn_wang',
                                           'cnn_zhang'] else 1
            num_stack = 2
            inputs, _, inputs_seq_len = generate_data(
                label_type='character',
                model='ctc',
                batch_size=batch_size,
                num_stack=num_stack,
                splice=splice)
            frame_num, input_size = inputs[0].shape

            # Define model graph
            if encoder_type in ['blstm', 'lstm']:
                encoder = load(encoder_type)(
                    num_units=256,
                    num_proj=None,
                    num_layers=5,
                    lstm_impl=lstm_impl,
                    use_peephole=True,
                    parameter_init=0.1,
                    clip_activation=5,
                    time_major=time_major)
            elif encoder_type in ['bgru', 'gru']:
                encoder = load(encoder_type)(
                    num_units=256,
                    num_layers=5,
                    parameter_init=0.1,
                    time_major=time_major)
            elif encoder_type in ['vgg_blstm', 'vgg_lstm', 'cldnn_wang']:
                encoder = load(encoder_type)(
                    input_size=input_size // splice // num_stack,
                    splice=splice,
                    num_stack=num_stack,
                    num_units=256,
                    num_proj=None,
                    num_layers=5,
                    lstm_impl=lstm_impl,
                    use_peephole=True,
                    parameter_init=0.1,
                    clip_activation=5,
                    time_major=time_major)
            elif encoder_type in ['multitask_blstm', 'multitask_lstm']:
                encoder = load(encoder_type)(
                    num_units=256,
                    num_proj=None,
                    num_layers_main=5,
                    num_layers_sub=3,
                    lstm_impl=lstm_impl,
                    use_peephole=True,
                    parameter_init=0.1,
                    clip_activation=5,
                    time_major=time_major)
            elif encoder_type in ['vgg_wang', 'resnet_wang', 'cnn_zhang']:
                encoder = load(encoder_type)(
                    input_size=input_size // splice // num_stack,
                    splice=splice,
                    num_stack=num_stack,
                    parameter_init=0.1,
                    time_major=time_major)
                # NOTE: topology is pre-defined
            else:
                raise NotImplementedError

            # Create placeholders
            inputs_pl = tf.placeholder(tf.float32,
                                       shape=[None, None, input_size],
                                       name='inputs')
            inputs_seq_len_pl = tf.placeholder(tf.int32,
                                               shape=[None],
                                               name='inputs_seq_len')
            keep_prob_pl = tf.placeholder(tf.float32, name='keep_prob')

            # Operations for forward computation
            if encoder_type in ['multitask_blstm', 'multitask_lstm']:
                hidden_states_op, final_state_op, hidden_states_sub_op, final_state_sub_op = encoder(
                    inputs=inputs_pl,
                    inputs_seq_len=inputs_seq_len_pl,
                    keep_prob=keep_prob_pl,
                    is_training=True)
            else:
                hidden_states_op, final_state_op = encoder(
                    inputs=inputs_pl,
                    inputs_seq_len=inputs_seq_len_pl,
                    keep_prob=keep_prob_pl,
                    is_training=True)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()),
                   "{:,}".format(total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                inputs_pl: inputs,
                inputs_seq_len_pl: inputs_seq_len,
                keep_prob_pl: 0.9
            }

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Make prediction
                if encoder_type in ['multitask_blstm', 'multitask_lstm']:
                    encoder_outputs, final_state, hidden_states_sub, final_state_sub = sess.run(
                        [hidden_states_op, final_state_op,
                         hidden_states_sub_op, final_state_sub_op],
                        feed_dict=feed_dict)
                elif encoder_type in ['vgg_wang', 'resnet_wang', 'cnn_zhang']:
                    encoder_outputs = sess.run(
                        hidden_states_op, feed_dict=feed_dict)
                else:
                    encoder_outputs, final_state = sess.run(
                        [hidden_states_op, final_state_op],
                        feed_dict=feed_dict)

                # Always convert to batch-major
                if time_major:
                    encoder_outputs = encoder_outputs.transpose(1, 0, 2)

                if encoder_type in ['blstm', 'bgru', 'vgg_blstm', 'multitask_blstm', 'cldnn_wang']:
                    if encoder_type != 'cldnn_wang':
                        self.assertEqual(
                            (batch_size, frame_num, encoder.num_units * 2), encoder_outputs.shape)

                    if encoder_type != 'bgru':
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].c.shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].h.shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[1].c.shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[1].h.shape)

                        if encoder_type == 'multitask_blstm':
                            self.assertEqual(
                                (batch_size, frame_num, encoder.num_units * 2), hidden_states_sub.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[0].c.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[0].h.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[1].c.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[1].h.shape)
                    else:
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[1].shape)

                elif encoder_type in ['lstm', 'gru', 'vgg_lstm', 'multitask_lstm']:
                    self.assertEqual(
                        (batch_size, frame_num, encoder.num_units), encoder_outputs.shape)

                    if encoder_type != 'gru':
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].c.shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].h.shape)

                        if encoder_type == 'multitask_lstm':
                            self.assertEqual(
                                (batch_size, frame_num, encoder.num_units), hidden_states_sub.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[0].c.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[0].h.shape)
                    else:
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].shape)

                elif encoder_type in ['vgg_wang', 'resnet_wang', 'cnn_zhang']:
                    self.assertEqual(3, len(encoder_outputs.shape))
                    self.assertEqual(
                        (batch_size, frame_num), encoder_outputs.shape[:2])
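
# count_total_parameters(tf.trainable_variables()) is called throughout this
# file to report model size. A minimal sketch of what that helper is assumed
# to compute (hypothetical; the repo's implementation may differ in detail):
def _count_total_parameters_sketch(trainable_variables):
    """Return ({variable_name: num_elements}, total number of elements)."""
    parameters_dict = {}
    total_parameters = 0
    for variable in trainable_variables:
        num_elements = 1
        for dim in variable.get_shape():
            num_elements *= int(dim)  # multiply out the static shape
        parameters_dict[variable.name] = num_elements
        total_parameters += num_elements
    return parameters_dict, total_parameters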
def do_fine_tune(network, optimizer, learning_rate, batch_size, epoch_num,
                 label_type, num_stack, num_skip, social_signal_type,
                 trained_model_path, restore_epoch=None):
    """Run training.
    Args:
        network: network to train
        optimizer: adam or adadelta or rmsprop
        learning_rate: initial learning rate
        batch_size: size of mini batch
        epoch_num: number of epochs to train
        label_type: phone or character
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
        social_signal_type: insert or insert2 or insert3 or remove
        trained_model_path: path to the pre-trained model
        restore_epoch: epoch of the model to restore
    """
    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default():
        # Read dataset
        train_data = DataSetDialog(data_type='train', label_type=label_type,
                                   social_signal_type=social_signal_type,
                                   num_stack=num_stack, num_skip=num_skip,
                                   is_sorted=True)
        dev_data = DataSetDialog(data_type='dev', label_type=label_type,
                                 social_signal_type=social_signal_type,
                                 num_stack=num_stack, num_skip=num_skip,
                                 is_sorted=False)
        test_data = DataSetDialog(data_type='test', label_type=label_type,
                                  social_signal_type=social_signal_type,
                                  num_stack=num_stack, num_skip=num_skip,
                                  is_sorted=False)
        # TODO: create these evaluation datasets
        # eval1_data = DataSet(data_type='eval1', label_type=label_type,
        #                      social_signal_type=social_signal_type,
        #                      num_stack=num_stack, num_skip=num_skip,
        #                      is_sorted=False)
        # eval2_data = DataSet(data_type='eval2', label_type=label_type,
        #                      social_signal_type=social_signal_type,
        #                      num_stack=num_stack, num_skip=num_skip,
        #                      is_sorted=False)
        # eval3_data = DataSet(data_type='eval3', label_type=label_type,
        #                      social_signal_type=social_signal_type,
        #                      num_stack=num_stack, num_skip=num_skip,
        #                      is_sorted=False)

        # Add to the graph each operation
        loss_op = network.loss()
        train_op = network.train(optimizer=optimizer,
                                 learning_rate_init=learning_rate,
                                 is_scheduled=False)
        decode_op = network.decoder(decode_type='beam_search',
                                    beam_width=20)
        per_op = network.ler(decode_op)

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(network.summaries_train)
        summary_dev = tf.summary.merge(network.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()), "{:,}".format(total_parameters / 1000000)))

        csv_steps = []
        csv_train_loss = []
        csv_dev_loss = []

        # Create a session for running operations on the graph
        with tf.Session() as sess:
            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(
                network.model_dir, sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Restore pre-trained model's parameters
            ckpt = tf.train.get_checkpoint_state(trained_model_path)
            if ckpt:
                # Use last saved model
                model_path = ckpt.model_checkpoint_path
                if restore_epoch is not None:
                    model_path = model_path.split('/')[:-1]
                    model_path = '/'.join(model_path) + \
                        '/model.ckpt-' + str(restore_epoch)
            else:
                raise ValueError('There are no checkpoints.')
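            # Exclude the output-layer variables from restoration; they are
            # presumably re-initialized because the fine-tuning label set
            # differs from the pre-training one.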
            exclude = ['output/Variable', 'output/Variable_1']
            variables_to_restore = slim.get_variables_to_restore(
                exclude=exclude)
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, model_path)
            print("Model restored: " + model_path)

            # Train model
            iter_per_epoch = int(train_data.data_num / batch_size)
            if (train_data.data_num / batch_size) != int(train_data.data_num / batch_size):
                iter_per_epoch += 1
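            # i.e. iter_per_epoch = ceil(data_num / batch_size): a final
            # partial batch counts as one extra iteration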
            max_steps = iter_per_epoch * epoch_num
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            fmean_best = 0
            for step in range(max_steps):
                # Create feed dictionary for next mini batch (train)
                inputs, labels, seq_len, _ = train_data.next_batch(
                    batch_size=batch_size)
                indices, values, dense_shape = list2sparsetensor(labels)
                feed_dict_train = {
                    network.inputs_pl: inputs,
                    network.label_indices_pl: indices,
                    network.label_values_pl: values,
                    network.label_shape_pl: dense_shape,
                    network.seq_len_pl: seq_len,
                    network.keep_prob_input_pl: network.dropout_ratio_input,
                    network.keep_prob_hidden_pl: network.dropout_ratio_hidden,
                    network.lr_pl: learning_rate
                }

                # Create feed dictionary for next mini batch (dev)
                inputs, labels, seq_len, _ = dev_data.next_batch(
                    batch_size=batch_size)
                indices, values, dense_shape = list2sparsetensor(labels)
                feed_dict_dev = {
                    network.inputs_pl: inputs,
                    network.label_indices_pl: indices,
                    network.label_values_pl: values,
                    network.label_shape_pl: dense_shape,
                    network.seq_len_pl: seq_len,
                    network.keep_prob_input_pl: network.dropout_ratio_input,
                    network.keep_prob_hidden_pl: network.dropout_ratio_hidden
                }

                # Update parameters & compute loss
                _, loss_train = sess.run(
                    [train_op, loss_op], feed_dict=feed_dict_train)
                loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                csv_steps.append(step)
                csv_train_loss.append(loss_train)
                csv_dev_loss.append(loss_dev)

                if (step + 1) % 10 == 0:
                    # Change feed dict for evaluation
                    feed_dict_train[network.keep_prob_input_pl] = 1.0
                    feed_dict_train[network.keep_prob_hidden_pl] = 1.0
                    feed_dict_dev[network.keep_prob_input_pl] = 1.0
                    feed_dict_dev[network.keep_prob_hidden_pl] = 1.0

                    # Compute accuracy & update event files
                    ler_train, summary_str_train = sess.run([per_op, summary_train],
                                                            feed_dict=feed_dict_train)
                    ler_dev, summary_str_dev, labels_st = sess.run([per_op, summary_dev, decode_op],
                                                                   feed_dict=feed_dict_dev)
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    # Decode
                    # try:
                    #     labels_pred = sparsetensor2list(labels_st, batch_size)
                    # except:
                    #     labels_pred = [[0] * batch_size]

                    duration_step = time.time() - start_time_step
                    print('Step %d: loss = %.3f (%.3f) / ler = %.4f (%.4f) (%.3f min)' %
                          (step + 1, loss_train, loss_dev, ler_train, ler_dev, duration_step / 60))

                    # if label_type == 'character':
                    #     if social_signal_type == 'remove':
                    #         map_file_path = '../evaluation/mapping_files/ctc/char2num_remove.txt'
                    #     else:
                    #         map_file_path = '../evaluation/mapping_files/ctc/char2num_' + \
                    #             social_signal_type + '.txt'
                    #     print('True: %s' % num2char(labels[-1], map_file_path))
                    #     print('Pred: %s' % num2char(
                    #         labels_pred[-1], map_file_path))
                    # elif label_type == 'phone':
                    #     if social_signal_type == 'remove':
                    #         map_file_path = '../evaluation/mapping_files/ctc/phone2num_remove.txt'
                    #     else:
                    #         map_file_path = '../evaluation/mapping_files/ctc/phone2num_' + \
                    #             social_signal_type + '.txt'
                    #     print('True: %s' % num2phone(
                    #         labels[-1], map_file_path))
                    #     print('Pred: %s' % num2phone(
                    #         labels_pred[-1], map_file_path))

                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if (step + 1) % iter_per_epoch == 0 or (step + 1) == max_steps:
                    duration_epoch = time.time() - start_time_epoch
                    epoch = (step + 1) // iter_per_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (epoch, duration_epoch / 60))

                    # Save model (check point)
                    checkpoint_file = os.path.join(
                        network.model_dir, 'model.ckpt')
                    save_path = saver.save(
                        sess, checkpoint_file, global_step=epoch)
                    print("Model saved in file: %s" % save_path)

                    start_time_eval = time.time()
                    if label_type == 'character':
                        print('■Dev Evaluation:■')
                        fmean_epoch = do_eval_fmeasure(session=sess, decode_op=decode_op,
                                                       network=network, dataset=dev_data,
                                                       label_type=label_type,
                                                       social_signal_type=social_signal_type)
                        # error_epoch = do_eval_cer(session=sess,
                        #                           decode_op=decode_op,
                        #                           network=network,
                        #                           dataset=dev_data,
                        #                           eval_batch_size=batch_size)

                        if fmean_epoch > fmean_best:
                            fmean_best = fmean_epoch
                            print('■■■ ↑Best Score (F-measure)↑ ■■■')

                            do_eval_fmeasure(session=sess, decode_op=decode_op,
                                             network=network, dataset=test_data,
                                             label_type=label_type,
                                             social_signal_type=social_signal_type)
                            # print('■eval1 Evaluation:■')
                            # do_eval_cer(session=sess, decode_op=decode_op,
                            #             network=network, dataset=eval1_data,
                            #             eval_batch_size=batch_size)
                            # print('■eval2 Evaluation:■')
                            # do_eval_cer(session=sess, decode_op=decode_op,
                            #             network=network, dataset=eval2_data,
                            #             eval_batch_size=batch_size)
                            # print('■eval3 Evaluation:■')
                            # do_eval_cer(session=sess, decode_op=decode_op,
                            #             network=network, dataset=eval3_data,
                            #             eval_batch_size=batch_size)

                    else:
                        print('■Dev Evaluation:■')
                        fmean_epoch = do_eval_fmeasure(session=sess, decode_op=decode_op,
                                                       network=network, dataset=dev_data,
                                                       label_type=label_type,
                                                       social_signal_type=social_signal_type)
                        # error_epoch = do_eval_per(session=sess,
                        #                           per_op=per_op,
                        #                           network=network,
                        #                           dataset=dev_data,
                        #                           eval_batch_size=batch_size)

                        if fmean_epoch > fmean_best:
                            fmean_best = fmean_epoch
                            print('■■■ ↑Best Score (F-measure)↑ ■■■')

                            do_eval_fmeasure(session=sess, decode_op=decode_op,
                                             network=network, dataset=test_data,
                                             label_type=label_type,
                                             social_signal_type=social_signal_type)
                            # print('■eval1 Evaluation:■')
                            # do_eval_per(session=sess, per_op=per_op,
                            #             network=network, dataset=eval1_data,
                            #             eval_batch_size=batch_size)
                            # print('■eval2 Evaluation:■')
                            # do_eval_per(session=sess, per_op=per_op,
                            #             network=network, dataset=eval2_data,
                            #             eval_batch_size=batch_size)
                            # print('■eval3 Evaluation:■')
                            # do_eval_per(session=sess, per_op=per_op,
                            #             network=network, dataset=eval3_data,
                            #             eval_batch_size=batch_size)

                    duration_eval = time.time() - start_time_eval
                    print('Evaluation time: %.3f min' %
                          (duration_eval / 60))

                    start_time_epoch = time.time()
                    start_time_step = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Save train & dev loss
            save_loss(csv_steps, csv_train_loss, csv_dev_loss,
                      save_path=network.model_dir)

            # Training was finished correctly
            with open(os.path.join(network.model_dir, 'complete.txt'), 'w') as f:
                f.write('')
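
# list2sparsetensor converts a batch of label sequences into the
# (indices, values, dense_shape) triple expected when feeding a sparse
# CTC-label placeholder. A minimal, hypothetical sketch (the helper is used
# with two call styles above; the padded_value handling is assumed):
import numpy as np

def _list2sparsetensor_sketch(labels, padded_value=-1):
    indices, values = [], []
    for batch_idx, seq in enumerate(labels):
        for time_idx, label in enumerate(seq):
            if label == padded_value:
                continue  # drop padding symbols
            indices.append([batch_idx, time_idx])
            values.append(label)
    dense_shape = [len(labels), max(len(seq) for seq in labels)]
    return (np.array(indices, dtype=np.int64),
            np.array(values, dtype=np.int32),
            np.array(dense_shape, dtype=np.int64))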
    def check(self, encoder_type, lstm_impl, time_major=False):

        print('==================================================')
        print('  encoder_type: %s' % str(encoder_type))
        print('  lstm_impl: %s' % str(lstm_impl))
        print('  time_major: %s' % str(time_major))
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 2
            inputs, labels_char, labels_phone, inputs_seq_len = generate_data(
                label_type='multitask', model='ctc', batch_size=batch_size)

            # Define model graph
            num_classes_main = 27
            num_classes_sub = 61
            model = MultitaskCTC(
                encoder_type=encoder_type,
                input_size=inputs[0].shape[1],
                num_units=256,
                num_layers_main=2,
                num_layers_sub=1,
                num_classes_main=num_classes_main,
                num_classes_sub=num_classes_sub,
                main_task_weight=0.8,
                lstm_impl=lstm_impl,
                parameter_init=0.1,
                clip_grad_norm=5.0,
                clip_activation=50,
                num_proj=256,
                weight_decay=1e-8,
                # bottleneck_dim=50,
                bottleneck_dim=None,
                time_major=time_major)

            # Define placeholders
            model.create_placeholders()
            learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

            # Add to the graph each operation
            loss_op, logits_main, logits_sub = model.compute_loss(
                model.inputs_pl_list[0], model.labels_pl_list[0],
                model.labels_sub_pl_list[0], model.inputs_seq_len_pl_list[0],
                model.keep_prob_pl_list[0])
            train_op = model.train(loss_op,
                                   optimizer='adam',
                                   learning_rate=learning_rate_pl)
            decode_op_main, decode_op_sub = model.decoder(
                logits_main,
                logits_sub,
                model.inputs_seq_len_pl_list[0],
                beam_width=20)
            ler_op_main, ler_op_sub = model.compute_ler(
                decode_op_main, decode_op_sub, model.labels_pl_list[0],
                model.labels_sub_pl_list[0])

            # Define learning rate controller
            learning_rate = 1e-3
            lr_controller = Controller(learning_rate_init=learning_rate,
                                       decay_start_epoch=20,
                                       decay_rate=0.9,
                                       decay_patient_epoch=5,
                                       lower_better=True)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()), "{:,}".format(
                      total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                model.inputs_pl_list[0]:
                inputs,
                model.labels_pl_list[0]:
                list2sparsetensor(labels_char, padded_value=-1),
                model.labels_sub_pl_list[0]:
                list2sparsetensor(labels_phone, padded_value=-1),
                model.inputs_seq_len_pl_list[0]:
                inputs_seq_len,
                model.keep_prob_pl_list[0]:
                0.9,
                learning_rate_pl:
                learning_rate
            }

            idx2phone = Idx2phone(map_file_path='./phone61.txt')

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 1000
                start_time_step = time.time()
                for step in range(max_steps):

                    # Compute loss
                    _, loss_train = sess.run([train_op, loss_op],
                                             feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(model.clipped_grads,
                    #                  feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Change to evaluation mode
                        feed_dict[model.keep_prob_pl_list[0]] = 1.0

                        # Compute accuracy
                        ler_train_char, ler_train_phone = sess.run(
                            [ler_op_main, ler_op_sub], feed_dict=feed_dict)

                        duration_step = time.time() - start_time_step
                        print(
                            'Step %d: loss = %.3f / cer = %.3f / per = %.3f (%.3f sec) / lr = %.5f'
                            % (step + 1, loss_train, ler_train_char,
                               ler_train_phone, duration_step, learning_rate))
                        start_time_step = time.time()

                        # Visualize
                        labels_pred_char_st, labels_pred_phone_st = sess.run(
                            [decode_op_main, decode_op_sub],
                            feed_dict=feed_dict)
                        labels_pred_char = sparsetensor2list(
                            labels_pred_char_st, batch_size=batch_size)
                        labels_pred_phone = sparsetensor2list(
                            labels_pred_phone_st, batch_size=batch_size)

                        print('Character')
                        try:
                            print('  Ref: %s' % idx2alpha(labels_char[0]))
                            print('  Hyp: %s' % idx2alpha(labels_pred_char[0]))
                        except IndexError:
                            print('  Ref: %s' % idx2alpha(labels_char[0]))
                            print('  Hyp: %s' % '')

                        print('Phone')
                        try:
                            print('  Ref: %s' % idx2phone(labels_phone[0]))
                            print('  Hyp: %s' %
                                  idx2phone(labels_pred_phone[0]))
                        except IndexError:
                            print('  Ref: %s' % idx2phone(labels_phone[0]))
                            print('  Hyp: %s' % '')
                            # NOTE: This is for no prediction
                        print('-' * 30)

                        if ler_train_char < 0.1:
                            print('Model is converged.')
                            break

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=step,
                            value=ler_train_char)
                        feed_dict[learning_rate_pl] = learning_rate
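
# Idx2phone (instantiated above with './phone61.txt') maps integer indices
# back to phone symbols via a mapping file. A minimal sketch assuming a
# "phone index" per-line file format (hypothetical; the actual file layout
# may differ):
class _Idx2phoneSketch(object):

    def __init__(self, map_file_path):
        self.idx2phone_map = {}
        with open(map_file_path, 'r') as f:
            for line in f:
                phone, idx = line.strip().split()
                self.idx2phone_map[int(idx)] = phone

    def __call__(self, indices):
        # Join phones with spaces, matching the printed hypotheses above
        return ' '.join(self.idx2phone_map[int(i)] for i in indices)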
    def check_training(self, label_type):
        print('----- ' + label_type + ' -----')
        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 1
            inputs, att_labels, inputs_seq_len, att_labels_seq_len, ctc_labels_st = generate_data(
                label_type=label_type,
                model='joint_ctc_attention',
                batch_size=batch_size)

            # Define model graph
            att_num_classes = 26 + 2 if label_type == 'character' else 61 + 2
            ctc_num_classes = 26 if label_type == 'character' else 61
            # model = load(model_type=model_type)
            network = JointCTCAttention(input_size=inputs[0].shape[1],
                                        encoder_num_unit=256,
                                        encoder_num_layer=2,
                                        attention_dim=128,
                                        attention_type='content',
                                        decoder_num_unit=256,
                                        decoder_num_layer=1,
                                        embedding_dim=20,
                                        att_num_classes=att_num_classes,
                                        ctc_num_classes=ctc_num_classes,
                                        att_task_weight=0.5,
                                        sos_index=att_num_classes - 2,
                                        eos_index=att_num_classes - 1,
                                        max_decode_length=50,
                                        attention_weights_tempareture=1.0,
                                        logits_tempareture=1.0,
                                        parameter_init=0.1,
                                        clip_grad=5.0,
                                        clip_activation_encoder=50,
                                        clip_activation_decoder=50,
                                        dropout_ratio_input=0.9,
                                        dropout_ratio_hidden=0.9,
                                        dropout_ratio_output=1.0,
                                        weight_decay=1e-8,
                                        beam_width=1,
                                        time_major=False)

            # Define placeholders
            network.create_placeholders()
            learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

            # Add to the graph each operation
            loss_op, att_logits, ctc_logits, decoder_outputs_train, decoder_outputs_infer = network.compute_loss(
                network.inputs_pl_list[0], network.att_labels_pl_list[0],
                network.inputs_seq_len_pl_list[0],
                network.att_labels_seq_len_pl_list[0],
                network.ctc_labels_pl_list[0],
                network.keep_prob_input_pl_list[0],
                network.keep_prob_hidden_pl_list[0],
                network.keep_prob_output_pl_list[0])
            train_op = network.train(loss_op,
                                     optimizer='adam',
                                     learning_rate=learning_rate_pl)
            decode_op_train, decode_op_infer = network.decoder(
                decoder_outputs_train, decoder_outputs_infer)
            ler_op = network.compute_ler(network.att_labels_st_true_pl,
                                         network.att_labels_st_pred_pl)

            # Define learning rate controller
            learning_rate = 1e-3
            lr_controller = Controller(learning_rate_init=learning_rate,
                                       decay_start_epoch=10,
                                       decay_rate=0.99,
                                       decay_patient_epoch=5,
                                       lower_better=True)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()), "{:,}".format(
                      total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                network.inputs_pl_list[0]: inputs,
                network.att_labels_pl_list[0]: att_labels,
                network.inputs_seq_len_pl_list[0]: inputs_seq_len,
                network.att_labels_seq_len_pl_list[0]: att_labels_seq_len,
                network.ctc_labels_pl_list[0]: ctc_labels_st,
                network.keep_prob_input_pl_list[0]:
                network.dropout_ratio_input,
                network.keep_prob_hidden_pl_list[0]:
                network.dropout_ratio_hidden,
                network.keep_prob_output_pl_list[0]:
                network.dropout_ratio_output,
                learning_rate_pl: learning_rate
            }

            map_file_path = '../../experiments/timit/metrics/mapping_files/attention/phone61_to_num.txt'

            with tf.Session() as sess:

                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 400
                start_time_global = time.time()
                start_time_step = time.time()
                ler_train_pre = 1
                not_improved_count = 0
                for step in range(max_steps):

                    # Compute loss
                    _, loss_train = sess.run([train_op, loss_op],
                                             feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(network.clipped_grads,
                    #                  feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Change to evaluation mode
                        feed_dict[network.keep_prob_input_pl_list[0]] = 1.0
                        feed_dict[network.keep_prob_hidden_pl_list[0]] = 1.0
                        feed_dict[network.keep_prob_output_pl_list[0]] = 1.0

                        # Predict class ids
                        predicted_ids_train, predicted_ids_infer = sess.run(
                            [decode_op_train, decode_op_infer],
                            feed_dict=feed_dict)

                        # Compute accuracy
                        try:
                            feed_dict_ler = {
                                network.att_labels_st_true_pl:
                                list2sparsetensor(att_labels, padded_value=27),
                                network.att_labels_st_pred_pl:
                                list2sparsetensor(predicted_ids_infer,
                                                  padded_value=27)
                            }
                            ler_train = sess.run(ler_op,
                                                 feed_dict=feed_dict_ler)
                        except ValueError:
                            ler_train = 1

                        duration_step = time.time() - start_time_step
                        print(
                            'Step %d: loss = %.3f / ler = %.4f (%.3f sec) / lr = %.5f'
                            % (step + 1, loss_train, ler_train, duration_step,
                               learning_rate))
                        start_time_step = time.time()

                        # Visualize
                        if label_type == 'character':
                            print('True            : %s' %
                                  idx2alpha(att_labels[0]))
                            print('Pred (Training) : <%s' %
                                  idx2alpha(predicted_ids_train[0]))
                            print('Pred (Inference): <%s' %
                                  idx2alpha(predicted_ids_infer[0]))
                        else:
                            print('True            : %s' %
                                  idx2phone(att_labels[0], map_file_path))
                            print('Pred (Training) : < %s' % idx2phone(
                                predicted_ids_train[0], map_file_path))
                            print('Pred (Inference): < %s' % idx2phone(
                                predicted_ids_infer[0], map_file_path))

                        if ler_train >= ler_train_pre:
                            not_improved_count += 1
                        else:
                            not_improved_count = 0
                        if not_improved_count >= 10:
                            print('Model is converged.')
                            break
                        ler_train_pre = ler_train

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=step,
                            value=ler_train)
                        feed_dict[learning_rate_pl] = learning_rate

                duration_global = time.time() - start_time_global
                print('Total time: %.3f sec' % (duration_global))
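
# sparsetensor2list is the inverse of list2sparsetensor: it unpacks the
# tf.SparseTensorValue returned by the CTC decoder into one label list per
# utterance. A minimal, hypothetical sketch:
import numpy as np

def _sparsetensor2list_sketch(labels_st, batch_size):
    indices = np.asarray(labels_st.indices)
    values = np.asarray(labels_st.values)
    labels = []
    for batch_idx in range(batch_size):
        mask = indices[:, 0] == batch_idx  # rows belonging to this utterance
        labels.append(values[mask].tolist())
    return labels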
def do_train(model, params):
    """Run training. If target labels are phone, the model is evaluated by PER
    with 39 phones.
    Args:
        model: the model to train
        params (dict): A dictionary of parameters
    """
    # Load dataset
    train_data = Dataset(data_type='train',
                         label_type=params['label_type'],
                         batch_size=params['batch_size'],
                         eos_index=params['eos_index'],
                         max_epoch=params['num_epoch'],
                         splice=params['splice'],
                         num_stack=params['num_stack'],
                         num_skip=params['num_skip'],
                         sort_utt=True)
    dev_data = Dataset(data_type='dev',
                       label_type=params['label_type'],
                       batch_size=params['batch_size'],
                       eos_index=params['eos_index'],
                       splice=params['splice'],
                       num_stack=params['num_stack'],
                       num_skip=params['num_skip'],
                       sort_utt=False)
    if 'char' in params['label_type']:
        test_data = Dataset(data_type='test',
                            label_type=params['label_type'],
                            batch_size=1,
                            eos_index=params['eos_index'],
                            splice=params['splice'],
                            num_stack=params['num_stack'],
                            num_skip=params['num_skip'],
                            sort_utt=False)
    else:
        test_data = Dataset(data_type='test',
                            label_type='phone39',
                            batch_size=1,
                            eos_index=params['eos_index'],
                            splice=params['splice'],
                            num_stack=params['num_stack'],
                            num_skip=params['num_skip'],
                            sort_utt=False)
    # TODO(hirofumi): add frame_stacking and splice

    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default():

        # Define placeholders
        model.create_placeholders()
        learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

        # Add to the graph each operation (including model definition)
        loss_op, att_logits, ctc_logits, decoder_outputs_train, decoder_outputs_infer = model.compute_loss(
            model.inputs_pl_list[0], model.att_labels_pl_list[0],
            model.inputs_seq_len_pl_list[0],
            model.att_labels_seq_len_pl_list[0], model.ctc_labels_pl_list[0],
            model.keep_prob_input_pl_list[0],
            model.keep_prob_hidden_pl_list[0],
            model.keep_prob_output_pl_list[0])
        train_op = model.train(loss_op,
                               optimizer=params['optimizer'],
                               learning_rate=learning_rate_pl)
        _, decode_op_infer = model.decoder(decoder_outputs_train,
                                           decoder_outputs_infer,
                                           decode_type='greedy',
                                           beam_width=20)
        ler_op = model.compute_ler(model.att_labels_st_true_pl,
                                   model.att_labels_st_pred_pl)

        # Define learning rate controller
        lr_controller = Controller(
            learning_rate_init=params['learning_rate'],
            decay_start_epoch=params['decay_start_epoch'],
            decay_rate=params['decay_rate'],
            decay_patient_epoch=params['decay_patient_epoch'],
            lower_better=True)

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(model.summaries_train)
        summary_dev = tf.summary.merge(model.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M param" %
              (len(parameters_dict.keys()), "{:,}".format(
                  total_parameters / 1000000)))

        csv_steps, csv_loss_train, csv_loss_dev = [], [], []
        csv_ler_train, csv_ler_dev = [], []
        # Create a session for running operations on the graph
        with tf.Session() as sess:

            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(model.save_path, sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Train model
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            ler_dev_best = 1
            learning_rate = float(params['learning_rate'])
            for step, (data, is_new_epoch) in enumerate(train_data):

                # Create feed dictionary for next mini batch (train)
                inputs, att_labels_train, ctc_labels, inputs_seq_len, att_labels_seq_len, _ = data
                feed_dict_train = {
                    model.inputs_pl_list[0]: inputs,
                    model.att_labels_pl_list[0]: att_labels_train,
                    model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                    model.att_labels_seq_len_pl_list[0]: att_labels_seq_len,
                    model.ctc_labels_pl_list[0]: list2sparsetensor(
                        ctc_labels, padded_value=train_data.ctc_padded_value),
                    model.keep_prob_input_pl_list[0]: params['dropout_input'],
                    model.keep_prob_hidden_pl_list[0]: params['dropout_hidden'],
                    model.keep_prob_output_pl_list[0]: params['dropout_output'],
                    learning_rate_pl: learning_rate
                }

                # Update parameters
                sess.run(train_op, feed_dict=feed_dict_train)

                if (step + 1) % params['print_step'] == 0:

                    # Create feed dictionary for next mini batch (dev)
                    (inputs, att_labels_dev, ctc_labels, inputs_seq_len,
                     att_labels_seq_len, _), _ = dev_data.next()
                    feed_dict_dev = {
                        model.inputs_pl_list[0]: inputs,
                        model.att_labels_pl_list[0]: att_labels_dev,
                        model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                        model.att_labels_seq_len_pl_list[0]: att_labels_seq_len,
                        model.ctc_labels_pl_list[0]: list2sparsetensor(
                            ctc_labels, padded_value=dev_data.ctc_padded_value),
                        model.keep_prob_input_pl_list[0]: 1.0,
                        model.keep_prob_hidden_pl_list[0]: 1.0,
                        model.keep_prob_output_pl_list[0]: 1.0
                    }

                    # Compute loss
                    loss_train = sess.run(loss_op, feed_dict=feed_dict_train)
                    loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                    csv_steps.append(step)
                    csv_loss_train.append(loss_train)
                    csv_loss_dev.append(loss_dev)

                    # Change to evaluation mode
                    feed_dict_train[model.keep_prob_input_pl_list[0]] = 1.0
                    feed_dict_train[model.keep_prob_hidden_pl_list[0]] = 1.0
                    feed_dict_train[model.keep_prob_output_pl_list[0]] = 1.0

                    # Predict class ids & update event files
                    predicted_ids_train, summary_str_train = sess.run(
                        [decode_op_infer, summary_train],
                        feed_dict=feed_dict_train)
                    predicted_ids_dev, summary_str_dev = sess.run(
                        [decode_op_infer, summary_dev],
                        feed_dict=feed_dict_dev)
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    # Convert to sparsetensor to compute LER
                    feed_dict_ler_train = {
                        model.att_labels_st_true_pl: list2sparsetensor(
                            att_labels_train, padded_value=params['eos_index']),
                        model.att_labels_st_pred_pl: list2sparsetensor(
                            predicted_ids_train, padded_value=params['eos_index'])
                    }
                    feed_dict_ler_dev = {
                        model.att_labels_st_true_pl: list2sparsetensor(
                            att_labels_dev, padded_value=params['eos_index']),
                        model.att_labels_st_pred_pl: list2sparsetensor(
                            predicted_ids_dev, padded_value=params['eos_index'])
                    }

                    # Compute accuracy
                    ler_train = sess.run(ler_op, feed_dict=feed_dict_ler_train)
                    ler_dev = sess.run(ler_op, feed_dict=feed_dict_ler_dev)
                    csv_ler_train.append(ler_train)
                    csv_ler_dev.append(ler_dev)

                    duration_step = time.time() - start_time_step
                    print(
                        "Step %d (epoch: %.3f): loss = %.3f (%.3f) / ler = %.3f (%.3f) / lr = %.5f (%.3f min)"
                        % (step + 1, train_data.epoch_detail, loss_train,
                           loss_dev, ler_train, ler_dev, learning_rate,
                           duration_step / 60))
                    # sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if is_new_epoch:
                    duration_epoch = time.time() - start_time_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (train_data.epoch, duration_epoch / 60))

                    # Save figures of loss & LER
                    plot_loss(csv_loss_train,
                              csv_loss_dev,
                              csv_steps,
                              save_path=model.save_path)
                    plot_ler(csv_ler_train,
                             csv_ler_dev,
                             csv_steps,
                             label_type=params['label_type'],
                             save_path=model.save_path)

                    if train_data.epoch >= params['eval_start_epoch']:
                        start_time_eval = time.time()
                        if 'char' in params['label_type']:
                            print('=== Dev Data Evaluation ===')
                            ler_dev_epoch = do_eval_cer(
                                session=sess,
                                decode_op=decode_op_infer,
                                model=model,
                                dataset=dev_data,
                                eval_batch_size=1)
                            print('  CER: %f %%' % (ler_dev_epoch * 100))

                            if ler_dev_epoch < ler_dev_best:
                                ler_dev_best = ler_dev_epoch
                                print('■■■ ↑Best Score (CER)↑ ■■■')

                                # Save model only when best accuracy is
                                # obtained (check point)
                                checkpoint_file = join(model.save_path,
                                                       'model.ckpt')
                                save_path = saver.save(
                                    sess,
                                    checkpoint_file,
                                    global_step=train_data.epoch)
                                print("Model saved in file: %s" % save_path)

                                print('=== Test Data Evaluation ===')
                                ler_test = do_eval_cer(
                                    session=sess,
                                    decode_op=decode_op_infer,
                                    model=model,
                                    dataset=test_data,
                                    eval_batch_size=1)
                                print('  CER: %f %%' % (ler_test * 100))

                        else:
                            print('=== Dev Data Evaluation ===')
                            ler_dev_epoch = do_eval_per(
                                session=sess,
                                decode_op=decode_op_infer,
                                per_op=ler_op,
                                model=model,
                                dataset=dev_data,
                                label_type=params['label_type'],
                                eval_batch_size=1)
                            print('  PER: %f %%' % (ler_dev_epoch * 100))

                            if ler_dev_epoch < ler_dev_best:
                                ler_dev_best = ler_dev_epoch
                                print('■■■ ↑Best Score (PER)↑ ■■■')

                                # Save model only when best accuracy is
                                # obtained (check point)
                                checkpoint_file = join(model.save_path,
                                                       'model.ckpt')
                                save_path = saver.save(
                                    sess,
                                    checkpoint_file,
                                    global_step=train_data.epoch)
                                print("Model saved in file: %s" % save_path)

                                print('=== Test Data Evaluation ===')
                                ler_test = do_eval_per(
                                    session=sess,
                                    decode_op=decode_op_infer,
                                    per_op=ler_op,
                                    model=model,
                                    dataset=test_data,
                                    label_type=params['label_type'],
                                    eval_batch_size=1)
                                print('  PER: %f %%' % (ler_test * 100))

                        duration_eval = time.time() - start_time_eval
                        print('Evaluation time: %.3f min' %
                              (duration_eval / 60))

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=train_data.epoch,
                            value=ler_dev_epoch)

                    start_time_epoch = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Record that training finished correctly
            with open(join(model.save_path, 'complete.txt'), 'w') as f:
                f.write('')
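For reference, here is a hypothetical invocation of the do_train function above. Every key in the params dict below is read somewhere in the function body; the concrete values and the model object are illustrative only. Note that this variant feeds params['dropout_input'] and friends directly into keep_prob placeholders, so the values act as keep probabilities despite the name.

params = {
    'label_type': 'character',
    'batch_size': 32,
    'eos_index': 1,            # hypothetical index of the <EOS> label
    'num_epoch': 30,
    'splice': 1,
    'num_stack': 2,
    'num_skip': 2,
    'optimizer': 'adam',
    'learning_rate': 1e-3,
    'decay_start_epoch': 10,
    'decay_rate': 0.9,
    'decay_patient_epoch': 1,
    'dropout_input': 0.9,      # fed to a keep_prob placeholder as-is
    'dropout_hidden': 0.9,
    'dropout_output': 1.0,
    'print_step': 100,
    'eval_start_epoch': 5,
}
# model = JointCTCAttentionModel(...)  # hypothetical model class
# do_train(model=model, params=params)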
Example #7
    def check(self, encoder_type, lstm_impl=None, time_major=False):

        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  lstm_impl: %s' % lstm_impl)
        print('  time_major: %s' % time_major)
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 4
            splice = 5 if encoder_type in [
                'vgg_blstm', 'vgg_lstm', 'vgg_wang', 'resnet_wang',
                'cldnn_wang', 'cnn_zhang'
            ] else 1
            num_stack = 2
            inputs, _, inputs_seq_len = generate_data(label_type='character',
                                                      model='ctc',
                                                      batch_size=batch_size,
                                                      num_stack=num_stack,
                                                      splice=splice)
            frame_num, input_size = inputs[0].shape

            # Define model graph
            if encoder_type in ['blstm', 'lstm']:
                encoder = load(encoder_type)(num_units=256,
                                             num_proj=None,
                                             num_layers=5,
                                             lstm_impl=lstm_impl,
                                             use_peephole=True,
                                             parameter_init=0.1,
                                             clip_activation=5,
                                             time_major=time_major)
            elif encoder_type in ['bgru', 'gru']:
                encoder = load(encoder_type)(num_units=256,
                                             num_layers=5,
                                             parameter_init=0.1,
                                             time_major=time_major)
            elif encoder_type in ['vgg_blstm', 'vgg_lstm', 'cldnn_wang']:
                encoder = load(encoder_type)(input_size=input_size // splice //
                                             num_stack,
                                             splice=splice,
                                             num_stack=num_stack,
                                             num_units=256,
                                             num_proj=None,
                                             num_layers=5,
                                             lstm_impl=lstm_impl,
                                             use_peephole=True,
                                             parameter_init=0.1,
                                             clip_activation=5,
                                             time_major=time_major)
            elif encoder_type in ['multitask_blstm', 'multitask_lstm']:
                encoder = load(encoder_type)(num_units=256,
                                             num_proj=None,
                                             num_layers_main=5,
                                             num_layers_sub=3,
                                             lstm_impl=lstm_impl,
                                             use_peephole=True,
                                             parameter_init=0.1,
                                             clip_activation=5,
                                             time_major=time_major)
            elif encoder_type in ['vgg_wang', 'resnet_wang', 'cnn_zhang']:
                encoder = load(encoder_type)(input_size=input_size // splice //
                                             num_stack,
                                             splice=splice,
                                             num_stack=num_stack,
                                             parameter_init=0.1,
                                             time_major=time_major)
                # NOTE: topology is pre-defined
            else:
                raise NotImplementedError

            # Create placeholders
            inputs_pl = tf.placeholder(tf.float32,
                                       shape=[None, None, input_size],
                                       name='inputs')
            inputs_seq_len_pl = tf.placeholder(tf.int32,
                                               shape=[None],
                                               name='inputs_seq_len')
            keep_prob_pl = tf.placeholder(tf.float32, name='keep_prob')

            # Operation for forward computation
            if encoder_type in ['multitask_blstm', 'multitask_lstm']:
                hidden_states_op, final_state_op, hidden_states_sub_op, final_state_sub_op = encoder(
                    inputs=inputs_pl,
                    inputs_seq_len=inputs_seq_len_pl,
                    keep_prob=keep_prob_pl,
                    is_training=True)
            else:
                hidden_states_op, final_state_op = encoder(
                    inputs=inputs_pl,
                    inputs_seq_len=inputs_seq_len_pl,
                    keep_prob=keep_prob_pl,
                    is_training=True)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()), "{:,}".format(
                      total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                inputs_pl: inputs,
                inputs_seq_len_pl: inputs_seq_len,
                keep_prob_pl: 0.9
            }

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Make prediction
                if encoder_type in ['multitask_blstm', 'multitask_lstm']:
                    encoder_outputs, final_state, hidden_states_sub, final_state_sub = sess.run(
                        [
                            hidden_states_op, final_state_op,
                            hidden_states_sub_op, final_state_sub_op
                        ],
                        feed_dict=feed_dict)
                elif encoder_type in ['vgg_wang', 'resnet_wang', 'cnn_zhang']:
                    encoder_outputs = sess.run(hidden_states_op,
                                               feed_dict=feed_dict)
                else:
                    encoder_outputs, final_state = sess.run(
                        [hidden_states_op, final_state_op],
                        feed_dict=feed_dict)

                # Always convert to batch-major
                if time_major:
                    encoder_outputs = encoder_outputs.transpose(1, 0, 2)

                if encoder_type in [
                        'blstm', 'bgru', 'vgg_blstm', 'multitask_blstm',
                        'cldnn_wang'
                ]:
                    if encoder_type != 'cldnn_wang':
                        self.assertEqual(
                            (batch_size, frame_num, encoder.num_units * 2),
                            encoder_outputs.shape)

                    if encoder_type != 'bgru':
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[0].c.shape)
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[0].h.shape)
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[1].c.shape)
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[1].h.shape)

                        if encoder_type == 'multitask_blstm':
                            self.assertEqual(
                                (batch_size, frame_num, encoder.num_units * 2),
                                hidden_states_sub.shape)
                            self.assertEqual((batch_size, encoder.num_units),
                                             final_state_sub[0].c.shape)
                            self.assertEqual((batch_size, encoder.num_units),
                                             final_state_sub[0].h.shape)
                            self.assertEqual((batch_size, encoder.num_units),
                                             final_state_sub[1].c.shape)
                            self.assertEqual((batch_size, encoder.num_units),
                                             final_state_sub[1].h.shape)
                    else:
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[0].shape)
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[1].shape)

                elif encoder_type in [
                        'lstm', 'gru', 'vgg_lstm', 'multitask_lstm'
                ]:
                    self.assertEqual(
                        (batch_size, frame_num, encoder.num_units),
                        encoder_outputs.shape)

                    if encoder_type != 'gru':
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[0].c.shape)
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[0].h.shape)

                        if encoder_type == 'multitask_lstm':
                            self.assertEqual(
                                (batch_size, frame_num, encoder.num_units),
                                hidden_states_sub.shape)
                            self.assertEqual((batch_size, encoder.num_units),
                                             final_state_sub[0].c.shape)
                            self.assertEqual((batch_size, encoder.num_units),
                                             final_state_sub[0].h.shape)
                    else:
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[0].shape)

                elif encoder_type in ['vgg_wang', 'resnet_wang', 'cnn_zhang']:
                    self.assertEqual(3, len(encoder_outputs.shape))
                    self.assertEqual((batch_size, frame_num),
                                     encoder_outputs.shape[:2])
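num_stack and num_skip recur throughout these examples. The dataset classes implement the frame stacking themselves; the sketch below only illustrates the idea, concatenating num_stack consecutive feature frames and advancing by num_skip, under the assumption of last-frame padding at the tail.

import numpy as np

def stack_frames(feats, num_stack, num_skip):
    """Concatenate num_stack consecutive frames, advancing by num_skip.

    Args:
        feats: np.ndarray of shape (num_frames, feat_dim)
    Returns:
        np.ndarray of shape (ceil(num_frames / num_skip), feat_dim * num_stack)
    """
    num_frames, feat_dim = feats.shape
    stacked = []
    for t in range(0, num_frames, num_skip):
        window = feats[t:t + num_stack]
        if len(window) < num_stack:
            # Pad the tail by repeating the last frame
            pad = np.tile(window[-1:], (num_stack - len(window), 1))
            window = np.concatenate([window, pad], axis=0)
        stacked.append(window.reshape(-1))
    return np.array(stacked)

feats = np.random.randn(100, 40).astype(np.float32)
print(stack_frames(feats, num_stack=2, num_skip=2).shape)  # (50, 80)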
Example #8
def do_train(network, optimizer, learning_rate, batch_size, epoch_num,
             label_type, num_stack, num_skip):
    """Run training.
    Args:
        network: the network to train
        optimizer: string, the name of the optimizer, e.g. adam or rmsprop
        learning_rate: float, the initial learning rate
        batch_size: int, the size of each mini-batch
        epoch_num: int, the number of epochs to train
        label_type: string, phone39, phone48, phone61, or character
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
    """
    # Load dataset
    train_data = DataSet(data_type='train', label_type=label_type,
                         num_stack=num_stack, num_skip=num_skip,
                         is_sorted=True)
    if label_type == 'character':
        dev_data = DataSet(data_type='dev', label_type='character',
                           num_stack=num_stack, num_skip=num_skip,
                           is_sorted=False)
        test_data = DataSet(data_type='test', label_type='character',
                            num_stack=num_stack, num_skip=num_skip,
                            is_sorted=False)
    else:
        dev_data = DataSet(data_type='dev', label_type='phone39',
                           num_stack=num_stack, num_skip=num_skip,
                           is_sorted=False)
        test_data = DataSet(data_type='test', label_type='phone39',
                            num_stack=num_stack, num_skip=num_skip,
                            is_sorted=False)

    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default():

        # Define model
        network.define()
        # NOTE: define model under tf.Graph()

        # Add to the graph each operation
        loss_op = network.loss()
        train_op = network.train(optimizer=optimizer,
                                 learning_rate_init=learning_rate,
                                 is_scheduled=False)
        decode_op = network.decoder(decode_type='beam_search',
                                    beam_width=20)
        per_op = network.ler(decode_op)

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(network.summaries_train)
        summary_dev = tf.summary.merge(network.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()), "{:,}".format(total_parameters / 1000000)))

        csv_steps = []
        csv_train_loss = []
        csv_dev_loss = []
        # Create a session for running operations on the graph
        with tf.Session() as sess:

            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(
                network.model_dir, sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Train model
            iter_per_epoch = int(train_data.data_num / batch_size)
            if train_data.data_num % batch_size != 0:
                iter_per_epoch += 1
            max_steps = iter_per_epoch * epoch_num
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            error_best = 1
            for step in range(max_steps):

                # Create feed dictionary for next mini batch (train)
                inputs, labels, seq_len, _ = train_data.next_batch(
                    batch_size=batch_size)
                indices, values, dense_shape = list2sparsetensor(labels)
                feed_dict_train = {
                    network.inputs_pl: inputs,
                    network.label_indices_pl: indices,
                    network.label_values_pl: values,
                    network.label_shape_pl: dense_shape,
                    network.seq_len_pl: seq_len,
                    network.keep_prob_input_pl: network.dropout_ratio_input,
                    network.keep_prob_hidden_pl: network.dropout_ratio_hidden,
                    network.lr_pl: learning_rate
                }

                # Create feed dictionary for next mini batch (dev)
                inputs, labels, seq_len, _ = dev_data.next_batch(
                    batch_size=batch_size)
                indices, values, dense_shape = list2sparsetensor(labels)
                feed_dict_dev = {
                    network.inputs_pl: inputs,
                    network.label_indices_pl: indices,
                    network.label_values_pl: values,
                    network.label_shape_pl: dense_shape,
                    network.seq_len_pl: seq_len,
                    network.keep_prob_input_pl: network.dropout_ratio_input,
                    network.keep_prob_hidden_pl: network.dropout_ratio_hidden
                }

                # Update parameters & compute loss
                _, loss_train = sess.run(
                    [train_op, loss_op], feed_dict=feed_dict_train)
                loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                csv_steps.append(step)
                csv_train_loss.append(loss_train)
                csv_dev_loss.append(loss_dev)

                if (step + 1) % 10 == 0:

                    # Change feed dict for evaluation
                    feed_dict_train[network.keep_prob_input_pl] = 1.0
                    feed_dict_train[network.keep_prob_hidden_pl] = 1.0
                    feed_dict_dev[network.keep_prob_input_pl] = 1.0
                    feed_dict_dev[network.keep_prob_hidden_pl] = 1.0

                    # Compute accuracy & update event file
                    ler_train, summary_str_train = sess.run([per_op, summary_train],
                                                            feed_dict=feed_dict_train)
                    ler_dev, summary_str_dev, labels_st = sess.run([per_op, summary_dev, decode_op],
                                                                   feed_dict=feed_dict_dev)
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    duration_step = time.time() - start_time_step
                    print('Step %d: loss = %.3f (%.3f) / ler = %.4f (%.4f) (%.3f min)' %
                          (step + 1, loss_train, loss_dev, ler_train, ler_dev, duration_step / 60))
                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if (step + 1) % iter_per_epoch == 0 or (step + 1) == max_steps:
                    duration_epoch = time.time() - start_time_epoch
                    epoch = (step + 1) // iter_per_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (epoch, duration_epoch / 60))

                    # Save model (check point)
                    checkpoint_file = join(network.model_dir, 'model.ckpt')
                    save_path = saver.save(
                        sess, checkpoint_file, global_step=epoch)
                    print("Model saved in file: %s" % save_path)

                    if epoch >= 10:
                        start_time_eval = time.time()
                        if label_type == 'character':
                            print('■Dev Data Evaluation:■')
                            error_epoch = do_eval_cer(session=sess,
                                                      decode_op=decode_op,
                                                      network=network,
                                                      dataset=dev_data,
                                                      eval_batch_size=1)

                            if error_epoch < error_best:
                                error_best = error_epoch
                                print('■■■ ↑Best Score (CER)↑ ■■■')

                                print('■Test Data Evaluation:■')
                                do_eval_cer(session=sess, decode_op=decode_op,
                                            network=network, dataset=test_data,
                                            eval_batch_size=1)

                        else:
                            print('■Dev Data Evaluation:■')
                            error_epoch = do_eval_per(session=sess,
                                                      decode_op=decode_op,
                                                      per_op=per_op,
                                                      network=network,
                                                      dataset=dev_data,
                                                      label_type=label_type,
                                                      eval_batch_size=1)

                            if error_epoch < error_best:
                                error_best = error_epoch
                                print('■■■ ↑Best Score (PER)↑ ■■■')

                                print('■Test Data Evaluation:■')
                                do_eval_per(session=sess, decode_op=decode_op,
                                            per_op=per_op, network=network,
                                            dataset=test_data,
                                            label_type=label_type,
                                            eval_batch_size=1)

                        duration_eval = time.time() - start_time_eval
                        print('Evaluation time: %.3f min' %
                              (duration_eval / 60))

                    start_time_epoch = time.time()
                    start_time_step = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Save train & dev loss
            save_loss(csv_steps, csv_train_loss, csv_dev_loss,
                      save_path=network.model_dir)

            # Record that training finished correctly
            with open(join(network.model_dir, 'complete.txt'), 'w') as f:
                f.write('')
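This variant unpacks list2sparsetensor into three separate placeholder feeds (indices, values, dense shape), which TensorFlow reassembles into the tf.SparseTensor the CTC loss consumes. Below is a minimal sketch of what such a helper has to compute; the project's helper elsewhere on this page also takes a padded_value argument, which this sketch omits.

import numpy as np

def list2sparsetensor_sketch(labels):
    """Convert a list of label sequences into SparseTensor components."""
    indices, values = [], []
    for batch_idx, seq in enumerate(labels):
        for time_idx, label in enumerate(seq):
            indices.append([batch_idx, time_idx])
            values.append(label)
    dense_shape = [len(labels), max(len(seq) for seq in labels)]
    return (np.array(indices, dtype=np.int64),
            np.array(values, dtype=np.int32),
            np.array(dense_shape, dtype=np.int64))

indices, values, dense_shape = list2sparsetensor_sketch([[3, 1, 2], [4, 5]])
print(dense_shape)  # -> [2 3]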
def do_train(model, params):
    """Run training. If target labels are phone, the model is evaluated by PER
    with 39 phones.
    Args:
        model: the model to train
        params (dict): A dictionary of parameters
    """
    map_file_path_train = '../metrics/mapping_files/' + \
        params['label_type'] + '.txt'
    if 'phone' in params['label_type']:
        map_file_path_eval = '../metrics/mapping_files/phone39.txt'
    else:
        map_file_path_eval = '../metrics/mapping_files/' + \
            params['label_type'] + '.txt'

    # Load dataset
    train_data = Dataset(
        data_type='train', label_type=params['label_type'],
        batch_size=params['batch_size'], map_file_path=map_file_path_train,
        max_epoch=params['num_epoch'], splice=params['splice'],
        num_stack=params['num_stack'], num_skip=params['num_skip'],
        sort_utt=True, sort_stop_epoch=params['sort_stop_epoch'])
    dev_data = Dataset(
        data_type='dev', label_type=params['label_type'],
        batch_size=params['batch_size'], map_file_path=map_file_path_train,
        splice=params['splice'],
        num_stack=params['num_stack'], num_skip=params['num_skip'],
        sort_utt=False)
    if 'char' in params['label_type']:
        test_data = Dataset(
            data_type='test', label_type=params['label_type'],
            batch_size=1, map_file_path=map_file_path_eval,
            splice=params['splice'],
            num_stack=params['num_stack'], num_skip=params['num_skip'],
            sort_utt=False)
    else:
        test_data = Dataset(
            data_type='test', label_type='phone39',
            batch_size=1, map_file_path=map_file_path_eval,
            splice=params['splice'],
            num_stack=params['num_stack'], num_skip=params['num_skip'],
            sort_utt=False)

    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default():

        # Define placeholders
        model.create_placeholders()
        learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

        # Add to the graph each operation (including model definition)
        loss_op, logits, decoder_outputs_train, decoder_outputs_infer = model.compute_loss(
            model.inputs_pl_list[0],
            model.labels_pl_list[0],
            model.inputs_seq_len_pl_list[0],
            model.labels_seq_len_pl_list[0],
            model.keep_prob_encoder_pl_list[0],
            model.keep_prob_decoder_pl_list[0],
            model.keep_prob_embedding_pl_list[0])
        train_op = model.train(loss_op,
                               optimizer=params['optimizer'],
                               learning_rate=learning_rate_pl)
        _, decode_op_infer = model.decode(
            decoder_outputs_train,
            decoder_outputs_infer)
        ler_op = model.compute_ler(model.labels_st_true_pl,
                                   model.labels_st_pred_pl)

        # Define learning rate controller
        lr_controller = Controller(
            learning_rate_init=params['learning_rate'],
            decay_start_epoch=params['decay_start_epoch'],
            decay_rate=params['decay_rate'],
            decay_patient_epoch=params['decay_patient_epoch'],
            lower_better=True)

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(model.summaries_train)
        summary_dev = tf.summary.merge(model.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M param" %
              (len(parameters_dict.keys()),
               "{:,}".format(total_parameters / 1000000)))

        csv_steps, csv_loss_train, csv_loss_dev = [], [], []
        csv_ler_train, csv_ler_dev = [], []
        # Create a session for running operations on the graph
        with tf.Session() as sess:

            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(
                model.save_path, sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Train model
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            ler_dev_best = 1
            learning_rate = float(params['learning_rate'])
            for step, (data, is_new_epoch) in enumerate(train_data):

                # Create feed dictionary for next mini batch (train)
                inputs, labels_train, inputs_seq_len, labels_seq_len, _ = data
                feed_dict_train = {
                    model.inputs_pl_list[0]: inputs[0],
                    model.labels_pl_list[0]: labels_train[0],
                    model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
                    model.labels_seq_len_pl_list[0]: labels_seq_len[0],
                    model.keep_prob_encoder_pl_list[0]: 1 - float(params['dropout_encoder']),
                    model.keep_prob_decoder_pl_list[0]: 1 - float(params['dropout_decoder']),
                    model.keep_prob_embedding_pl_list[0]: 1 - float(params['dropout_embedding']),
                    learning_rate_pl: learning_rate
                }

                # Update parameters
                sess.run(train_op, feed_dict=feed_dict_train)

                if (step + 1) % params['print_step'] == 0:

                    # Create feed dictionary for next mini batch (dev)
                    (inputs, labels_dev, inputs_seq_len,
                     labels_seq_len, _), _ = dev_data.next()
                    feed_dict_dev = {
                        model.inputs_pl_list[0]: inputs[0],
                        model.labels_pl_list[0]: labels_dev[0],
                        model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
                        model.labels_seq_len_pl_list[0]: labels_seq_len[0],
                        model.keep_prob_encoder_pl_list[0]: 1.0,
                        model.keep_prob_decoder_pl_list[0]: 1.0,
                        model.keep_prob_embedding_pl_list[0]: 1.0
                    }

                    # Compute loss
                    loss_train = sess.run(loss_op, feed_dict=feed_dict_train)
                    loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                    csv_steps.append(step)
                    csv_loss_train.append(loss_train)
                    csv_loss_dev.append(loss_dev)

                    # Change to evaluation mode
                    feed_dict_train[model.keep_prob_encoder_pl_list[0]] = 1.0
                    feed_dict_train[model.keep_prob_decoder_pl_list[0]] = 1.0
                    feed_dict_train[model.keep_prob_embedding_pl_list[0]] = 1.0

                    # Predict class ids & update event files
                    predicted_ids_train, summary_str_train = sess.run(
                        [decode_op_infer, summary_train], feed_dict=feed_dict_train)
                    predicted_ids_dev, summary_str_dev = sess.run(
                        [decode_op_infer, summary_dev], feed_dict=feed_dict_dev)
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    # Convert to sparsetensor to compute LER
                    feed_dict_ler_train = {
                        model.labels_st_true_pl: list2sparsetensor(
                            labels_train[0], padded_value=train_data.padded_value),
                        model.labels_st_pred_pl: list2sparsetensor(
                            predicted_ids_train, padded_value=train_data.padded_value)
                    }
                    feed_dict_ler_dev = {
                        model.labels_st_true_pl: list2sparsetensor(
                            labels_dev[0], padded_value=dev_data.padded_value),
                        model.labels_st_pred_pl: list2sparsetensor(
                            predicted_ids_dev, padded_value=dev_data.padded_value)
                    }

                    # Compute accuracy
                    ler_train = sess.run(ler_op, feed_dict=feed_dict_ler_train)
                    ler_dev = sess.run(ler_op, feed_dict=feed_dict_ler_dev)
                    csv_ler_train.append(ler_train)
                    csv_ler_dev.append(ler_dev)

                    duration_step = time.time() - start_time_step
                    print("Step %d (epoch: %.3f): loss = %.3f (%.3f) / ler = %.3f (%.3f) / lr = %.5f (%.3f min)" %
                          (step + 1, train_data.epoch_detail, loss_train, loss_dev, ler_train, ler_dev,
                           learning_rate, duration_step / 60))
                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if is_new_epoch:
                    duration_epoch = time.time() - start_time_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (train_data.epoch, duration_epoch / 60))

                    # Save figures of loss & LER
                    plot_loss(csv_loss_train, csv_loss_dev, csv_steps,
                              save_path=model.save_path)
                    plot_ler(csv_ler_train, csv_ler_dev, csv_steps,
                             label_type=params['label_type'],
                             save_path=model.save_path)

                    if train_data.epoch >= params['eval_start_epoch']:
                        start_time_eval = time.time()
                        if 'char' in params['label_type']:
                            print('=== Dev Data Evaluation ===')
                            ler_dev_epoch, wer_dev_epoch = do_eval_cer(
                                session=sess,
                                decode_op=decode_op_infer,
                                model=model,
                                dataset=dev_data,
                                label_type=params['label_type'],
                                eval_batch_size=1)
                            print('  CER: %f %%' % (ler_dev_epoch * 100))
                            print('  WER: %f %%' % (wer_dev_epoch * 100))

                            if ler_dev_epoch < ler_dev_best:
                                ler_dev_best = ler_dev_epoch
                                print('■■■ ↑Best Score (CER)↑ ■■■')

                                # Save model only when best accuracy is
                                # obtained (check point)
                                checkpoint_file = join(
                                    model.save_path, 'model.ckpt')
                                save_path = saver.save(
                                    sess, checkpoint_file, global_step=train_data.epoch)
                                print("Model saved in file: %s" % save_path)

                                print('=== Test Data Evaluation ===')
                                ler_test, wer_test = do_eval_cer(
                                    session=sess,
                                    decode_op=decode_op_infer,
                                    model=model,
                                    dataset=test_data,
                                    label_type=params['label_type'],
                                    is_test=True,
                                    eval_batch_size=1)
                                print('  CER: %f %%' % (ler_test * 100))
                                print('  WER: %f %%' % (wer_test * 100))

                        else:
                            print('=== Dev Data Evaluation ===')
                            ler_dev_epoch = do_eval_per(
                                session=sess,
                                decode_op=decode_op_infer,
                                per_op=ler_op,
                                model=model,
                                dataset=dev_data,
                                label_type=params['label_type'],
                                eval_batch_size=1)
                            print('  PER: %f %%' % (ler_dev_epoch * 100))

                            if ler_dev_epoch < ler_dev_best:
                                ler_dev_best = ler_dev_epoch
                                print('■■■ ↑Best Score (PER)↑ ■■■')

                                # Save model only when best accuracy is
                                # obtained (check point)
                                checkpoint_file = join(
                                    model.save_path, 'model.ckpt')
                                save_path = saver.save(
                                    sess, checkpoint_file, global_step=train_data.epoch)
                                print("Model saved in file: %s" % save_path)

                                print('=== Test Data Evaluation ===')
                                ler_test = do_eval_per(
                                    session=sess,
                                    decode_op=decode_op_infer,
                                    per_op=ler_op,
                                    model=model,
                                    dataset=test_data,
                                    label_type=params['label_type'],
                                    is_test=True,
                                    eval_batch_size=1)
                                print('  PER: %f %%' % (ler_test * 100))

                        duration_eval = time.time() - start_time_eval
                        print('Evaluation time: %.3f min' %
                              (duration_eval / 60))

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=train_data.epoch,
                            value=ler_dev_epoch)

                    start_time_step = time.time()
                    start_time_epoch = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Record that training finished correctly
            with open(join(model.save_path, 'complete.txt'), 'w') as f:
                f.write('')
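plot_loss and plot_ler are called once per epoch above to write training curves next to the checkpoints. Here is a minimal sketch of the loss plot under a headless matplotlib backend; the project's helpers presumably add styling and the LER variant, and the output file name below is an assumption.

import os
import matplotlib
matplotlib.use('Agg')  # render without a display
import matplotlib.pyplot as plt

def plot_loss_sketch(loss_train, loss_dev, steps, save_path):
    """Save a train/dev loss curve as loss.png under save_path."""
    plt.figure()
    plt.plot(steps, loss_train, label='train loss')
    plt.plot(steps, loss_dev, label='dev loss')
    plt.xlabel('step')
    plt.ylabel('loss')
    plt.legend(loc='upper right')
    plt.savefig(os.path.join(save_path, 'loss.png'))
    plt.close()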
Example #10
    def check_training(self,
                       encoder_type,
                       label_type,
                       lstm_impl='LSTMBlockCell',
                       save_params=False):

        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  label_type: %s' % label_type)
        print('  lstm_impl: %s' % lstm_impl)
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 1
            splice = 11 if encoder_type in [
                'vgg_blstm', 'vgg_lstm', 'vgg_wang', 'resnet_wang', 'cnn_zhang'
            ] else 1
            inputs, labels_true_st, inputs_seq_len = generate_data(
                label_type=label_type,
                model='ctc',
                batch_size=batch_size,
                splice=splice)
            # NOTE: input_size must be an even number when using CudnnLSTM

            # Define model graph
            num_classes = 26 if label_type == 'character' else 61
            model = CTC(
                encoder_type=encoder_type,
                input_size=inputs[0].shape[-1] // splice,
                splice=splice,
                num_units=256,
                num_layers=2,
                num_classes=num_classes,
                lstm_impl=lstm_impl,
                parameter_init=0.1,
                clip_grad=5.0,
                clip_activation=50,
                num_proj=256,
                # bottleneck_dim=50,
                bottleneck_dim=None,
                weight_decay=1e-8)

            # Define placeholders
            model.create_placeholders()
            learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

            # Add to the graph each operation
            loss_op, logits = model.compute_loss(
                model.inputs_pl_list[0], model.labels_pl_list[0],
                model.inputs_seq_len_pl_list[0],
                model.keep_prob_input_pl_list[0],
                model.keep_prob_hidden_pl_list[0],
                model.keep_prob_output_pl_list[0])
            train_op = model.train(loss_op,
                                   optimizer='adam',
                                   learning_rate=learning_rate_pl)
            # NOTE: Adam does not run on CudnnLSTM
            decode_op = model.decoder(logits,
                                      model.inputs_seq_len_pl_list[0],
                                      beam_width=20)
            ler_op = model.compute_ler(decode_op, model.labels_pl_list[0])

            # Define learning rate controller
            learning_rate = 1e-3
            lr_controller = Controller(learning_rate_init=learning_rate,
                                       decay_start_epoch=10,
                                       decay_rate=0.98,
                                       decay_patient_epoch=5,
                                       lower_better=True)

            if save_params:
                # Create a saver for writing training checkpoints
                saver = tf.train.Saver(max_to_keep=None)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            if lstm_impl != 'CudnnLSTM':
                parameters_dict, total_parameters = count_total_parameters(
                    tf.trainable_variables())
                for parameter_name in sorted(parameters_dict.keys()):
                    print("%s %d" %
                          (parameter_name, parameters_dict[parameter_name]))
                print("Total %d variables, %s M parameters" %
                      (len(parameters_dict.keys()), "{:,}".format(
                          total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                model.inputs_pl_list[0]: inputs,
                model.labels_pl_list[0]: labels_true_st,
                model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                model.keep_prob_input_pl_list[0]: 1.0,
                model.keep_prob_hidden_pl_list[0]: 1.0,
                model.keep_prob_output_pl_list[0]: 1.0,
                learning_rate_pl: learning_rate
            }

            idx2phone = Idx2phone(map_file_path='./phone61_ctc.txt')

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 1000
                start_time_global = time.time()
                start_time_step = time.time()
                ler_train_pre = 1
                not_improved_count = 0
                for step in range(max_steps):

                    # Compute loss
                    _, loss_train = sess.run([train_op, loss_op],
                                             feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(model.clipped_grads,
                    #                  feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Change to evaluation mode
                        feed_dict[model.keep_prob_input_pl_list[0]] = 1.0
                        feed_dict[model.keep_prob_hidden_pl_list[0]] = 1.0
                        feed_dict[model.keep_prob_output_pl_list[0]] = 1.0

                        # Compute accuracy
                        ler_train = sess.run(ler_op, feed_dict=feed_dict)

                        duration_step = time.time() - start_time_step
                        print(
                            'Step %d: loss = %.3f / ler = %.3f (%.3f sec) / lr = %.5f'
                            % (step + 1, loss_train, ler_train, duration_step,
                               learning_rate))
                        start_time_step = time.time()

                        # Decode
                        labels_pred_st = sess.run(decode_op,
                                                  feed_dict=feed_dict)
                        labels_true = sparsetensor2list(labels_true_st,
                                                        batch_size=batch_size)

                        # Visualize
                        try:
                            labels_pred = sparsetensor2list(
                                labels_pred_st, batch_size=batch_size)
                            if label_type == 'character':
                                print('Ref: %s' % idx2alpha(labels_true[0]))
                                print('Hyp: %s' % idx2alpha(labels_pred[0]))
                            else:
                                print('Ref: %s' % idx2phone(labels_true[0]))
                                print('Hyp: %s' % idx2phone(labels_pred[0]))

                        except IndexError:
                            if label_type == 'character':
                                print('Ref: %s' % idx2alpha(labels_true[0]))
                                print('Hyp: %s' % '')
                            else:
                                print('Ref: %s' % idx2phone(labels_true[0]))
                                print('Hyp: %s' % '')
                            # NOTE: This is for no prediction

                        if ler_train >= ler_train_pre:
                            not_improved_count += 1
                        else:
                            not_improved_count = 0
                        if ler_train < 0.05:
                            print('Model converged.')
                            if save_params:
                                # Save model (check point)
                                checkpoint_file = './model.ckpt'
                                save_path = saver.save(sess,
                                                       checkpoint_file,
                                                       global_step=1)
                                print("Model saved in file: %s" % save_path)
                            break
                        ler_train_pre = ler_train

                        # Update learning rate
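                        # (presumably the controller keeps the rate fixed until
                        # decay_start_epoch, then multiplies it by decay_rate
                        # when the monitored LER stops improving; `step` stands
                        # in for the epoch index in this test)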
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=step,
                            value=ler_train)
                        feed_dict[learning_rate_pl] = learning_rate

                duration_global = time.time() - start_time_global
                print('Total time: %.3f sec' % (duration_global))
def do_train(network, optimizer, learning_rate, batch_size, epoch_num,
             label_type, num_stack, num_skip):
    """Run training. If target labels are phone, the model is evaluated by PER
    with 39 phones.
    Args:
        network: network to train
        optimizer: string, the name of optimizer.
            ex.) adam, rmsprop
        learning_rate: A float value, the initial learning rate
        batch_size: int, the size of mini-batch
        epoch_num: int, the number of epochs to train
        label_type: string, phone39 or phone48 or phone61 or character
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
    """
    # Load dataset
    train_data = DataSet(data_type='train',
                         label_type=label_type,
                         batch_size=batch_size,
                         num_stack=num_stack,
                         num_skip=num_skip,
                         is_sorted=True)
    dev_data = DataSet(data_type='dev',
                       label_type=label_type,
                       batch_size=batch_size,
                       num_stack=num_stack,
                       num_skip=num_skip,
                       is_sorted=False)
    if label_type == 'character':
        test_data = DataSet(data_type='test',
                            label_type='character',
                            batch_size=1,
                            num_stack=num_stack,
                            num_skip=num_skip,
                            is_sorted=False)
    else:
        test_data = DataSet(data_type='test',
                            label_type='phone39',
                            batch_size=1,
                            num_stack=num_stack,
                            num_skip=num_skip,
                            is_sorted=False)

    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default():

        # Define placeholders
        network.inputs = tf.placeholder(tf.float32,
                                        shape=[None, None, network.input_size],
                                        name='input')
        indices_pl = tf.placeholder(tf.int64, name='indices')
        values_pl = tf.placeholder(tf.int32, name='values')
        shape_pl = tf.placeholder(tf.int64, name='shape')
        network.labels = tf.SparseTensor(indices_pl, values_pl, shape_pl)
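        # NOTE: CTC labels are fed as a tf.SparseTensor assembled from three
        # placeholders (indices, values, dense_shape), since label sequences
        # in a mini-batch have different lengths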
        network.inputs_seq_len = tf.placeholder(tf.int64,
                                                shape=[None],
                                                name='inputs_seq_len')
        network.keep_prob_input = tf.placeholder(tf.float32,
                                                 name='keep_prob_input')
        network.keep_prob_hidden = tf.placeholder(tf.float32,
                                                  name='keep_prob_hidden')

        # Add to the graph each operation (including model definition)
        loss_op, logits = network.compute_loss(network.inputs, network.labels,
                                               network.inputs_seq_len,
                                               network.keep_prob_input,
                                               network.keep_prob_hidden)
        train_op = network.train(loss_op,
                                 optimizer=optimizer,
                                 learning_rate_init=float(learning_rate),
                                 is_scheduled=False)
        decode_op = network.decoder(logits,
                                    network.inputs_seq_len,
                                    decode_type='beam_search',
                                    beam_width=20)
        ler_op = network.compute_ler(decode_op, network.labels)

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(network.summaries_train)
        summary_dev = tf.summary.merge(network.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()), "{:,}".format(
                  total_parameters / 1000000)))

        # Make mini-batch generator
        mini_batch_train = train_data.next_batch()
        mini_batch_dev = dev_data.next_batch()

        csv_steps, csv_loss_train, csv_loss_dev = [], [], []
        csv_ler_train, csv_ler_dev = [], []
        # Create a session for running operation on the graph
        with tf.Session() as sess:

            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(network.model_dir,
                                                   sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Train model
            iter_per_epoch = int(train_data.data_num / batch_size)
            train_step = train_data.data_num / batch_size
            if train_step != int(train_step):
                iter_per_epoch += 1
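            # (i.e. ceil(train_data.data_num / batch_size): a trailing partial
            # mini-batch still counts as one iteration)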
            max_steps = iter_per_epoch * epoch_num
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            error_best = 1
            for step in range(max_steps):

                # Create feed dictionary for next mini batch (train)
                inputs, labels_st, inputs_seq_len, _ = next(mini_batch_train)
                feed_dict_train = {
                    network.inputs: inputs,
                    network.labels: labels_st,
                    network.inputs_seq_len: inputs_seq_len,
                    network.keep_prob_input: network.dropout_ratio_input,
                    network.keep_prob_hidden: network.dropout_ratio_hidden,
                    network.lr: learning_rate
                }

                # Create feed dictionary for next mini batch (dev)
                inputs, labels_st, inputs_seq_len, _ = next(mini_batch_dev)
                feed_dict_dev = {
                    network.inputs: inputs,
                    network.labels: labels_st,
                    network.inputs_seq_len: inputs_seq_len,
                    network.keep_prob_input: network.dropout_ratio_input,
                    network.keep_prob_hidden: network.dropout_ratio_hidden
                }

                # Update parameters
                sess.run(train_op, feed_dict=feed_dict_train)

                if (step + 1) % 10 == 0:

                    # Compute loss
                    loss_train = sess.run(loss_op, feed_dict=feed_dict_train)
                    loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                    csv_steps.append(step)
                    csv_loss_train.append(loss_train)
                    csv_loss_dev.append(loss_dev)

                    # Change to evaluation mode
                    feed_dict_train[network.keep_prob_input] = 1.0
                    feed_dict_train[network.keep_prob_hidden] = 1.0
                    feed_dict_dev[network.keep_prob_input] = 1.0
                    feed_dict_dev[network.keep_prob_hidden] = 1.0

                    # Compute accuracy & update event file
                    ler_train, summary_str_train = sess.run(
                        [ler_op, summary_train], feed_dict=feed_dict_train)
                    ler_dev, summary_str_dev = sess.run(
                        [ler_op, summary_dev], feed_dict=feed_dict_dev)
                    csv_ler_train.append(ler_train)
                    csv_ler_dev.append(ler_dev)
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    duration_step = time.time() - start_time_step
                    print(
                        "Step %d: loss = %.3f (%.3f) / ler = %.4f (%.4f) (%.3f min)"
                        % (step + 1, loss_train, loss_dev, ler_train, ler_dev,
                           duration_step / 60))
                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if (step + 1) % iter_per_epoch == 0 or (step + 1) == max_steps:
                    duration_epoch = time.time() - start_time_epoch
                    epoch = (step + 1) // iter_per_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (epoch, duration_epoch / 60))

                    # Save model (checkpoint)
                    checkpoint_file = join(network.model_dir, 'model.ckpt')
                    save_path = saver.save(sess,
                                           checkpoint_file,
                                           global_step=epoch)
                    print("Model saved in file: %s" % save_path)

                    if epoch >= 10:
                        start_time_eval = time.time()

                        if label_type == 'character':
                            print('=== Dev Data Evaluation ===')
                            cer_dev_epoch = do_eval_cer(session=sess,
                                                        decode_op=decode_op,
                                                        network=network,
                                                        dataset=dev_data)
                            print('  CER: %f %%' % (cer_dev_epoch * 100))

                            if cer_dev_epoch < error_best:
                                error_best = cer_dev_epoch
                                print('■■■ ↑Best Score (CER)↑ ■■■')

                                print('=== Test Data Evaluation ===')
                                cer_test = do_eval_cer(session=sess,
                                                       decode_op=decode_op,
                                                       network=network,
                                                       dataset=test_data,
                                                       eval_batch_size=1)
                                print('  CER: %f %%' % (cer_test * 100))

                        else:
                            print('=== Dev Data Evaluation ===')
                            per_dev_epoch = do_eval_per(
                                session=sess,
                                decode_op=decode_op,
                                per_op=ler_op,
                                network=network,
                                dataset=dev_data,
                                train_label_type=label_type)
                            print('  PER: %f %%' % (per_dev_epoch * 100))

                            if per_dev_epoch < error_best:
                                error_best = per_dev_epoch
                                print('■■■ ↑Best Score (PER)↑ ■■■')

                                print('=== Test Data Evaluation ===')
                                per_test = do_eval_per(
                                    session=sess,
                                    decode_op=decode_op,
                                    per_op=ler_op,
                                    network=network,
                                    dataset=test_data,
                                    train_label_type=label_type,
                                    eval_batch_size=1)
                                print('  PER: %f %%' % (per_test * 100))

                        duration_eval = time.time() - start_time_eval
                        print('Evaluation time: %.3f min' %
                              (duration_eval / 60))

                start_time_epoch = time.time()
                start_time_step = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Save train & dev loss, ler
            save_loss(csv_steps,
                      csv_loss_train,
                      csv_loss_dev,
                      save_path=network.model_dir)
            save_ler(csv_steps,
                     csv_ler_train,
                     csv_ler_dev,
                     save_path=network.model_dir)

            # Training was finished correctly
            with open(join(network.model_dir, 'complete.txt'), 'w') as f:
                f.write('')
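
# A minimal usage sketch (hypothetical values; the network object is assumed
# to expose the compute_loss/train/decoder/compute_ler interface used above):
#
# network = ...  # e.g. a CTC network with model_dir and dropout ratios set
# do_train(network=network, optimizer='adam', learning_rate=1e-3,
#          batch_size=32, epoch_num=20, label_type='phone61',
#          num_stack=3, num_skip=3)
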
def do_fine_tune(network,
                 optimizer,
                 learning_rate,
                 batch_size,
                 epoch_num,
                 label_type,
                 num_stack,
                 num_skip,
                 social_signal_type,
                 trained_model_path,
                 restore_epoch=None):
    """Run training.
    Args:
        network: network to train
        optimizer: adam or adadelta or rmsprop
        learning_rate: initial learning rate
        batch_size: int, the size of each mini-batch
        epoch_num: int, the number of epochs to train
        label_type: phone or character
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
        social_signal_type: insert or insert2 or insert3 or remove
        trained_model_path: path to the pre-trained model
        restore_epoch: epoch of the model to restore
    """
    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default():
        # Read dataset
        train_data = DataSetDialog(data_type='train',
                                   label_type=label_type,
                                   social_signal_type=social_signal_type,
                                   num_stack=num_stack,
                                   num_skip=num_skip,
                                   is_sorted=True)
        dev_data = DataSetDialog(data_type='dev',
                                 label_type=label_type,
                                 social_signal_type=social_signal_type,
                                 num_stack=num_stack,
                                 num_skip=num_skip,
                                 is_sorted=False)
        test_data = DataSetDialog(data_type='test',
                                  label_type=label_type,
                                  social_signal_type=social_signal_type,
                                  num_stack=num_stack,
                                  num_skip=num_skip,
                                  is_sorted=False)
        # TODO: create these datasets
        # eval1_data = DataSet(data_type='eval1', label_type=label_type,
        #                      social_signal_type=social_signal_type,
        #                      num_stack=num_stack, num_skip=num_skip,
        #                      is_sorted=False)
        # eval2_data = DataSet(data_type='eval2', label_type=label_type,
        #                      social_signal_type=social_signal_type,
        #                      num_stack=num_stack, num_skip=num_skip,
        #                      is_sorted=False)
        # eval3_data = DataSet(data_type='eval3', label_type=label_type,
        #                      social_signal_type=social_signal_type,
        #                      num_stack=num_stack, num_skip=num_skip,
        #                      is_sorted=False)

        # Add to the graph each operation
        loss_op = network.loss()
        train_op = network.train(optimizer=optimizer,
                                 learning_rate_init=learning_rate,
                                 is_scheduled=False)
        decode_op = network.decoder(decode_type='beam_search', beam_width=20)
        per_op = network.ler(decode_op)

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(network.summaries_train)
        summary_dev = tf.summary.merge(network.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()), "{:,}".format(
                  total_parameters / 1000000)))

        csv_steps = []
        csv_train_loss = []
        csv_dev_loss = []

        # Create a session for running operation on the graph
        with tf.Session() as sess:
            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(network.model_dir,
                                                   sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Restore pre-trained model's parameters
            ckpt = tf.train.get_checkpoint_state(trained_model_path)
            if ckpt:
                # Use the last saved model
                model_path = ckpt.model_checkpoint_path
                if restore_epoch is not None:
                    model_path = model_path.split('/')[:-1]
                    model_path = '/'.join(model_path) + \
                        '/model.ckpt-' + str(restore_epoch)
            else:
                raise ValueError('No checkpoints found.')
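            # Restore all pre-trained variables except the output layer,
            # which is presumably sized for the original label set and is
            # therefore re-initialized for the new task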
            exclude = ['output/Variable', 'output/Variable_1']
            variables_to_restore = slim.get_variables_to_restore(
                exclude=exclude)
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, model_path)
            print("Model restored: " + model_path)

            # Train model
            iter_per_epoch = int(train_data.data_num / batch_size)
            if (train_data.data_num / batch_size) != int(
                    train_data.data_num / batch_size):
                iter_per_epoch += 1
            max_steps = iter_per_epoch * epoch_num
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            fmean_best = 0
            for step in range(max_steps):
                # Create feed dictionary for next mini batch (train)
                inputs, labels, seq_len, _ = train_data.next_batch(
                    batch_size=batch_size)
                indices, values, dense_shape = list2sparsetensor(labels)
                feed_dict_train = {
                    network.inputs_pl: inputs,
                    network.label_indices_pl: indices,
                    network.label_values_pl: values,
                    network.label_shape_pl: dense_shape,
                    network.seq_len_pl: seq_len,
                    network.keep_prob_input_pl: network.dropout_ratio_input,
                    network.keep_prob_hidden_pl: network.dropout_ratio_hidden,
                    network.lr_pl: learning_rate
                }

                # Create feed dictionary for next mini batch (dev)
                inputs, labels, seq_len, _ = dev_data.next_batch(
                    batch_size=batch_size)
                indices, values, dense_shape = list2sparsetensor(labels)
                feed_dict_dev = {
                    network.inputs_pl: inputs,
                    network.label_indices_pl: indices,
                    network.label_values_pl: values,
                    network.label_shape_pl: dense_shape,
                    network.seq_len_pl: seq_len,
                    network.keep_prob_input_pl: network.dropout_ratio_input,
                    network.keep_prob_hidden_pl: network.dropout_ratio_hidden
                }

                # Update parameters & compute loss
                _, loss_train = sess.run([train_op, loss_op],
                                         feed_dict=feed_dict_train)
                loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                csv_steps.append(step)
                csv_train_loss.append(loss_train)
                csv_dev_loss.append(loss_dev)

                if (step + 1) % 10 == 0:
                    # Change feed dict for evaluation
                    feed_dict_train[network.keep_prob_input_pl] = 1.0
                    feed_dict_train[network.keep_prob_hidden_pl] = 1.0
                    feed_dict_dev[network.keep_prob_input_pl] = 1.0
                    feed_dict_dev[network.keep_prob_hidden_pl] = 1.0

                    # Compute accuracy & \update event file
                    ler_train, summary_str_train = sess.run(
                        [per_op, summary_train], feed_dict=feed_dict_train)
                    ler_dev, summary_str_dev, labels_st = sess.run(
                        [per_op, summary_dev, decode_op],
                        feed_dict=feed_dict_dev)
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    # Decode
                    # try:
                    #     labels_pred = sparsetensor2list(labels_st, batch_size)
                    # except:
                    #     labels_pred = [[0] * batch_size]

                    duration_step = time.time() - start_time_step
                    print(
                        'Step %d: loss = %.3f (%.3f) / ler = %.4f (%.4f) (%.3f min)'
                        % (step + 1, loss_train, loss_dev, ler_train, ler_dev,
                           duration_step / 60))

                    # if label_type == 'character':
                    #     if social_signal_type == 'remove':
                    #         map_file_path = '../evaluation/mapping_files/ctc/char2num_remove.txt'
                    #     else:
                    #         map_file_path = '../evaluation/mapping_files/ctc/char2num_' + \
                    #             social_signal_type + '.txt'
                    #     print('True: %s' % num2char(labels[-1], map_file_path))
                    #     print('Pred: %s' % num2char(
                    #         labels_pred[-1], map_file_path))
                    # elif label_type == 'phone':
                    #     if social_signal_type == 'remove':
                    #         map_file_path = '../evaluation/mapping_files/ctc/phone2num_remove.txt'
                    #     else:
                    #         map_file_path = '../evaluation/mapping_files/ctc/phone2num_' + \
                    #             social_signal_type + '.txt'
                    #     print('True: %s' % num2phone(
                    #         labels[-1], map_file_path))
                    #     print('Pred: %s' % num2phone(
                    #         labels_pred[-1], map_file_path))

                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if (step + 1) % iter_per_epoch == 0 or (step + 1) == max_steps:
                    duration_epoch = time.time() - start_time_epoch
                    epoch = (step + 1) // iter_per_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (epoch, duration_epoch / 60))

                    # Save model (checkpoint)
                    checkpoint_file = os.path.join(network.model_dir,
                                                   'model.ckpt')
                    save_path = saver.save(sess,
                                           checkpoint_file,
                                           global_step=epoch)
                    print("Model saved in file: %s" % save_path)

                    start_time_eval = time.time()
                    if label_type == 'character':
                        print('■Dev Evaluation:■')
                        fmean_epoch = do_eval_fmeasure(
                            session=sess,
                            decode_op=decode_op,
                            network=network,
                            dataset=dev_data,
                            label_type=label_type,
                            social_signal_type=social_signal_type)
                        # error_epoch = do_eval_cer(session=sess,
                        #                           decode_op=decode_op,
                        #                           network=network,
                        #                           dataset=dev_data,
                        #                           eval_batch_size=batch_size)

                        if fmean_epoch > fmean_best:
                            fmean_best = fmean_epoch
                            print('■■■ ↑Best Score (F-measure)↑ ■■■')

                            do_eval_fmeasure(
                                session=sess,
                                decode_op=decode_op,
                                network=network,
                                dataset=test_data,
                                label_type=label_type,
                                social_signal_type=social_signal_type)
                            # print('■eval1 Evaluation:■')
                            # do_eval_cer(session=sess, decode_op=decode_op,
                            #             network=network, dataset=eval1_data,
                            #             eval_batch_size=batch_size)
                            # print('■eval2 Evaluation:■')
                            # do_eval_cer(session=sess, decode_op=decode_op,
                            #             network=network, dataset=eval2_data,
                            #             eval_batch_size=batch_size)
                            # print('■eval3 Evaluation:■')
                            # do_eval_cer(session=sess, decode_op=decode_op,
                            #             network=network, dataset=eval3_data,
                            #             eval_batch_size=batch_size)

                    else:
                        print('■Dev Evaluation:■')
                        fmean_epoch = do_eval_fmeasure(
                            session=sess,
                            decode_op=decode_op,
                            network=network,
                            dataset=dev_data,
                            label_type=label_type,
                            social_signal_type=social_signal_type)
                        # error_epoch = do_eval_per(session=sess,
                        #                           per_op=per_op,
                        #                           network=network,
                        #                           dataset=dev_data,
                        #                           eval_batch_size=batch_size)

                        if fmean_epoch < fmean_best:
                            fmean_best = fmean_epoch
                            print('■■■ ↑Best Score (F-measure)↑ ■■■')

                            do_eval_fmeasure(
                                session=sess,
                                decode_op=decode_op,
                                network=network,
                                dataset=test_data,
                                label_type=label_type,
                                social_signal_type=social_signal_type)
                            # print('■eval1 Evaluation:■')
                            # do_eval_per(session=sess, per_op=per_op,
                            #             network=network, dataset=eval1_data,
                            #             eval_batch_size=batch_size)
                            # print('■eval2 Evaluation:■')
                            # do_eval_per(session=sess, per_op=per_op,
                            #             network=network, dataset=eval2_data,
                            #             eval_batch_size=batch_size)
                            # print('■eval3 Evaluation:■')
                            # do_eval_per(session=sess, per_op=per_op,
                            #             network=network, dataset=eval3_data,
                            #             eval_batch_size=batch_size)

                    duration_eval = time.time() - start_time_eval
                    print('Evaluation time: %.3f min' % (duration_eval / 60))

                    start_time_epoch = time.time()
                    start_time_step = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Save train & dev loss
            save_loss(csv_steps,
                      csv_train_loss,
                      csv_dev_loss,
                      save_path=network.model_dir)

            # Training was finished correctly
            with open(os.path.join(network.model_dir, 'complete.txt'),
                      'w') as f:
                f.write('')
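
# A minimal usage sketch (hypothetical paths and values):
#
# do_fine_tune(network=network, optimizer='adam', learning_rate=1e-3,
#              batch_size=32, epoch_num=10, label_type='phone',
#              num_stack=3, num_skip=3, social_signal_type='insert',
#              trained_model_path='./pretrained_model', restore_epoch=10)
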
    def check_encode(self, encoder_type, lstm_impl=None):

        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  lstm_impl: %s' % lstm_impl)
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 4
            splice = 11 if encoder_type in ['vgg_blstm', 'vgg_lstm', 'vgg_wang',
                                            'resnet_wang', 'cnn_zhang'] else 1
            inputs, _, inputs_seq_len = generate_data(
                label_type='character',
                model='ctc',
                batch_size=batch_size,
                splice=splice)
            frame_num, input_size = inputs[0].shape

            # Define model graph
            if encoder_type in ['blstm', 'lstm']:
                encoder = load(encoder_type)(
                    num_units=256,
                    num_layers=5,
                    num_classes=0,  # return hidden states
                    lstm_impl=lstm_impl,
                    parameter_init=0.1)
            elif encoder_type in ['bgru', 'gru']:
                encoder = load(encoder_type)(
                    num_units=256,
                    num_layers=5,
                    num_classes=0,  # return hidden states
                    parameter_init=0.1)
            elif encoder_type in ['vgg_blstm', 'vgg_lstm']:
                encoder = load(encoder_type)(
                    input_size=input_size // 11,
                    splice=11,
                    num_units=256,
                    num_layers=5,
                    num_classes=0,  # return hidden states
                    lstm_impl=lstm_impl,
                    parameter_init=0.1)
            elif encoder_type in ['multitask_blstm', 'multitask_lstm']:
                encoder = load(encoder_type)(
                    num_units=256,
                    num_layers_main=5,
                    num_layers_sub=3,
                    num_classes_main=0,  # return hidden states
                    num_classes_sub=0,  # return hidden states
                    lstm_impl=lstm_impl,
                    parameter_init=0.1)
            elif encoder_type in ['vgg_wang', 'resnet_wang', 'cnn_zhang']:
                encoder = load(encoder_type)(
                    input_size=input_size // 11,
                    splice=11,
                    num_classes=27,
                    parameter_init=0.1)
                # NOTE: topology is pre-defined
            else:
                raise NotImplementedError

            # Create placeholders
            inputs_pl = tf.placeholder(tf.float32,
                                       shape=[None, None, input_size],
                                       name='inputs')
            inputs_seq_len_pl = tf.placeholder(tf.int32,
                                               shape=[None],
                                               name='inputs_seq_len')
            keep_prob_input_pl = tf.placeholder(tf.float32,
                                                name='keep_prob_input')
            keep_prob_hidden_pl = tf.placeholder(tf.float32,
                                                 name='keep_prob_hidden')
            keep_prob_output_pl = tf.placeholder(tf.float32,
                                                 name='keep_prob_output')

            # Operation for forward computation
            if encoder_type in ['multitask_blstm', 'multitask_lstm']:
                hidden_states_op, final_state_op, hidden_states_sub_op, final_state_sub_op = encoder(
                    inputs=inputs_pl,
                    inputs_seq_len=inputs_seq_len_pl,
                    keep_prob_input=keep_prob_input_pl,
                    keep_prob_hidden=keep_prob_hidden_pl,
                    keep_prob_output=keep_prob_output_pl)
            else:
                hidden_states_op, final_state_op = encoder(
                    inputs=inputs_pl,
                    inputs_seq_len=inputs_seq_len_pl,
                    keep_prob_input=keep_prob_input_pl,
                    keep_prob_hidden=keep_prob_hidden_pl,
                    keep_prob_output=keep_prob_output_pl)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()),
                   "{:,}".format(total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                inputs_pl: inputs,
                inputs_seq_len_pl: inputs_seq_len,
                keep_prob_input_pl: 0.9,
                keep_prob_hidden_pl: 0.9,
                keep_prob_output_pl: 1.0
            }

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Make prediction
                if encoder_type in ['multitask_blstm', 'multitask_lstm']:
                    hidden_states, final_state, hidden_states_sub, final_state_sub = sess.run(
                        [hidden_states_op, final_state_op, hidden_states_sub_op, final_state_sub_op], feed_dict=feed_dict)
                elif encoder_type in ['vgg_wang', 'resnet_wang', 'cnn_zhang']:
                    hidden_states = sess.run(
                        hidden_states_op, feed_dict=feed_dict)
                else:
                    hidden_states, final_state = sess.run(
                        [hidden_states_op, final_state_op], feed_dict=feed_dict)

                if encoder_type in ['blstm', 'bgru', 'vgg_blstm', 'multitask_blstm']:
                    self.assertEqual(
                        (batch_size, frame_num, encoder.num_units * 2), hidden_states.shape)

                    if encoder_type in ['blstm', 'vgg_blstm', 'multitask_blstm']:
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].c.shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].h.shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[1].c.shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[1].h.shape)

                        if encoder_type == 'multitask_blstm':
                            self.assertEqual(
                                (batch_size, frame_num, encoder.num_units * 2), hidden_states_sub.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[0].c.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[0].h.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[1].c.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[1].h.shape)
                    else:
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[1].shape)

                elif encoder_type in ['lstm', 'gru', 'vgg_lstm', 'multitask_lstm']:
                    self.assertEqual(
                        (batch_size, frame_num, encoder.num_units), hidden_states.shape)

                    if encoder_type in ['lstm', 'vgg_lstm', 'multitask_lstm']:
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].c.shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].h.shape)

                        if encoder_type == 'multitask_lstm':
                            self.assertEqual(
                                (batch_size, frame_num, encoder.num_units), hidden_states_sub.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[0].c.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[0].h.shape)
                    else:
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].shape)

                elif encoder_type in ['vgg_wang', 'resnet_wang', 'cnn_zhang']:
                    self.assertEqual(
                        (frame_num, batch_size, encoder.num_classes), hidden_states.shape)
    def check(self, encoder_type, label_type='character',
              lstm_impl=None, time_major=True, save_params=False):

        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  label_type: %s' % label_type)
        print('  lstm_impl: %s' % lstm_impl)
        print('  time_major: %s' % str(time_major))
        print('  save_params: %s' % str(save_params))
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 2
            splice = 11 if encoder_type in ['vgg_blstm', 'vgg_lstm', 'cnn_zhang',
                                            'vgg_wang', 'resnet_wang', 'cldnn_wang'] else 1
            num_stack = 2
            inputs, labels, inputs_seq_len = generate_data(
                label_type=label_type,
                model='ctc',
                batch_size=batch_size,
                num_stack=num_stack,
                splice=splice)
            # NOTE: input_size must be an even number when using CudnnLSTM

            # Define model graph
            num_classes = 27 if label_type == 'character' else 61
            model = CTC(encoder_type=encoder_type,
                        input_size=inputs[0].shape[-1] // splice // num_stack,
                        splice=splice,
                        num_stack=num_stack,
                        num_units=256,
                        num_layers=2,
                        num_classes=num_classes,
                        lstm_impl=lstm_impl,
                        parameter_init=0.1,
                        clip_grad_norm=5.0,
                        clip_activation=50,
                        num_proj=256,
                        weight_decay=1e-10,
                        # bottleneck_dim=50,
                        bottleneck_dim=None,
                        time_major=time_major)

            # Define placeholders
            model.create_placeholders()
            learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

            # Add to the graph each operation
            loss_op, logits = model.compute_loss(
                model.inputs_pl_list[0],
                model.labels_pl_list[0],
                model.inputs_seq_len_pl_list[0],
                model.keep_prob_pl_list[0])
            train_op = model.train(loss_op,
                                   optimizer='nestrov',
                                   learning_rate=learning_rate_pl)
            # NOTE: Adam does not run on CudnnLSTM
            decode_op = model.decoder(logits,
                                      model.inputs_seq_len_pl_list[0],
                                      beam_width=20)
            ler_op = model.compute_ler(decode_op, model.labels_pl_list[0])

            # Define learning rate controller
            learning_rate = 1e-4
            lr_controller = Controller(learning_rate_init=learning_rate,
                                       decay_start_epoch=50,
                                       decay_rate=0.9,
                                       decay_patient_epoch=10,
                                       lower_better=True)

            if save_params:
                # Create a saver for writing training checkpoints
                saver = tf.train.Saver(max_to_keep=None)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            if lstm_impl != 'CudnnLSTM':
                parameters_dict, total_parameters = count_total_parameters(
                    tf.trainable_variables())
                for parameter_name in sorted(parameters_dict.keys()):
                    print("%s %d" %
                          (parameter_name, parameters_dict[parameter_name]))
                print("Total %d variables, %s M parameters" %
                      (len(parameters_dict.keys()),
                       "{:,}".format(total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                model.inputs_pl_list[0]: inputs,
                model.labels_pl_list[0]: list2sparsetensor(labels, padded_value=-1),
                model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                model.keep_prob_pl_list[0]: 1.0,
                learning_rate_pl: learning_rate
            }
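            # NOTE: padded_value=-1 presumably marks the padding entries that
            # list2sparsetensor drops when building the SparseTensor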

            idx2phone = Idx2phone(map_file_path='./phone61.txt')

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 1000
                start_time_step = time.time()
                for step in range(max_steps):

                    # for debug
                    # encoder_outputs = sess.run(
                    #     model.encoder_outputs, feed_dict)
                    # print(encoder_outputs.shape)

                    # Compute loss
                    _, loss_train = sess.run(
                        [train_op, loss_op], feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(model.clipped_grads,
                    #                  feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Change to evaluation mode
                        feed_dict[model.keep_prob_pl_list[0]] = 1.0

                        # Compute accuracy
                        ler_train = sess.run(ler_op, feed_dict=feed_dict)

                        duration_step = time.time() - start_time_step
                        print('Step %d: loss = %.3f / ler = %.3f (%.3f sec) / lr = %.5f' %
                              (step + 1, loss_train, ler_train, duration_step, learning_rate))
                        start_time_step = time.time()

                        # Decode
                        labels_pred_st = sess.run(
                            decode_op, feed_dict=feed_dict)

                        # Visualize
                        try:
                            labels_pred = sparsetensor2list(
                                labels_pred_st, batch_size=batch_size)
                            if label_type == 'character':
                                print('Ref: %s' % idx2alpha(labels[0]))
                                print('Hyp: %s' % idx2alpha(labels_pred[0]))
                            else:
                                print('Ref: %s' % idx2phone(labels[0]))
                                print('Hyp: %s' % idx2phone(labels_pred[0]))

                        except IndexError:
                            if label_type == 'character':
                                print('Ref: %s' % idx2alpha(labels[0]))
                                print('Hyp: %s' % '')
                            else:
                                print('Ref: %s' % idx2phone(labels[0]))
                                print('Hyp: %s' % '')
                            # NOTE: reached when the decoder emits no prediction

                        if ler_train < 0.1:
                            print('Model is Converged.')
                            if save_params:
                                # Save model (checkpoint)
                                checkpoint_file = './model.ckpt'
                                save_path = saver.save(
                                    sess, checkpoint_file, global_step=2)
                                print("Model saved in file: %s" % save_path)
                            break

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=step,
                            value=ler_train)
                        feed_dict[learning_rate_pl] = learning_rate
    def check(self, encoder_type, attention_type, label_type='character'):

        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  attention_type: %s' % attention_type)
        print('  label_type: %s' % label_type)
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 4
            inputs, labels, inputs_seq_len, labels_seq_len = generate_data(
                label_type=label_type,
                model='attention',
                batch_size=batch_size)

            # Define model graph
            num_classes = 27 if label_type == 'character' else 61
            model = AttentionSeq2Seq(input_size=inputs[0].shape[1],
                                     encoder_type=encoder_type,
                                     encoder_num_units=256,
                                     encoder_num_layers=2,
                                     encoder_num_proj=None,
                                     attention_type=attention_type,
                                     attention_dim=128,
                                     decoder_type='lstm',
                                     decoder_num_units=256,
                                     decoder_num_layers=1,
                                     embedding_dim=64,
                                     num_classes=num_classes,
                                     sos_index=num_classes,
                                     eos_index=num_classes + 1,
                                     max_decode_length=100,
                                     use_peephole=True,
                                     splice=1,
                                     parameter_init=0.1,
                                     clip_grad_norm=5.0,
                                     clip_activation_encoder=50,
                                     clip_activation_decoder=50,
                                     weight_decay=1e-8,
                                     time_major=True,
                                     sharpening_factor=1.0,
                                     logits_temperature=1.0)
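            # NOTE: <sos> and <eos> take the two ids directly after the
            # regular classes (num_classes and num_classes + 1)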

            # Define placeholders
            model.create_placeholders()
            learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

            # Add to the graph each operation
            loss_op, logits, decoder_outputs_train, decoder_outputs_infer = model.compute_loss(
                model.inputs_pl_list[0],
                model.labels_pl_list[0],
                model.inputs_seq_len_pl_list[0],
                model.labels_seq_len_pl_list[0],
                model.keep_prob_encoder_pl_list[0],
                model.keep_prob_decoder_pl_list[0],
                model.keep_prob_embedding_pl_list[0])
            train_op = model.train(loss_op,
                                   optimizer='adam',
                                   learning_rate=learning_rate_pl)
            decode_op_train, decode_op_infer = model.decode(
                decoder_outputs_train, decoder_outputs_infer)
            ler_op = model.compute_ler(model.labels_st_true_pl,
                                       model.labels_st_pred_pl)

            # Define learning rate controller
            learning_rate = 1e-3
            lr_controller = Controller(learning_rate_init=learning_rate,
                                       decay_start_epoch=20,
                                       decay_rate=0.9,
                                       decay_patient_epoch=10,
                                       lower_better=True)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()),
                   "{:,}".format(total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                model.inputs_pl_list[0]: inputs,
                model.labels_pl_list[0]: labels,
                model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                model.labels_seq_len_pl_list[0]: labels_seq_len,
                model.keep_prob_encoder_pl_list[0]: 0.8,
                model.keep_prob_decoder_pl_list[0]: 1.0,
                model.keep_prob_embedding_pl_list[0]: 1.0,
                learning_rate_pl: learning_rate
            }

            idx2phone = Idx2phone(map_file_path='./phone61.txt')

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 1000
                start_time_step = time.time()
                for step in range(max_steps):

                    # Compute loss
                    _, loss_train = sess.run(
                        [train_op, loss_op], feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(model.clipped_grads,
                    #                  feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Change to evaluation mode
                        feed_dict[model.keep_prob_encoder_pl_list[0]] = 1.0
                        feed_dict[model.keep_prob_decoder_pl_list[0]] = 1.0
                        feed_dict[model.keep_prob_embedding_pl_list[0]] = 1.0

                        # Predict class ids
                        predicted_ids_train, predicted_ids_infer = sess.run(
                            [decode_op_train, decode_op_infer],
                            feed_dict=feed_dict)

                        # Compute accuracy
                        try:
                            feed_dict_ler = {
                                model.labels_st_true_pl: list2sparsetensor(
                                    labels, padded_value=model.eos_index),
                                model.labels_st_pred_pl: list2sparsetensor(
                                    predicted_ids_infer, padded_value=model.eos_index)
                            }
                            ler_train = sess.run(
                                ler_op, feed_dict=feed_dict_ler)
                        except IndexError:
                            ler_train = 1

                        duration_step = time.time() - start_time_step
                        print('Step %d: loss = %.3f / ler = %.3f (%.3f sec) / lr = %.5f' %
                              (step + 1, loss_train, ler_train, duration_step, learning_rate))
                        start_time_step = time.time()

                        # Visualize
                        if label_type == 'character':
                            print('True            : %s' %
                                  idx2alpha(labels[0]))
                            print('Pred (Training) : <%s' %
                                  idx2alpha(predicted_ids_train[0]))
                            print('Pred (Inference): <%s' %
                                  idx2alpha(predicted_ids_infer[0]))
                        else:
                            print('True            : %s' %
                                  idx2phone(labels[0]))
                            print('Pred (Training) : < %s' %
                                  idx2phone(predicted_ids_train[0]))
                            print('Pred (Inference): < %s' %
                                  idx2phone(predicted_ids_infer[0]))

                        if ler_train < 0.1:
                            print('Model is Converged.')
                            break

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=step,
                            value=ler_train)
                        feed_dict[learning_rate_pl] = learning_rate
    def check(self, encoder_type, lstm_impl, time_major=False):

        print('==================================================')
        print('  encoder_type: %s' % str(encoder_type))
        print('  lstm_impl: %s' % str(lstm_impl))
        print('  time_major: %s' % str(time_major))
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 2
            inputs, labels_char, labels_phone, inputs_seq_len = generate_data(
                label_type='multitask',
                model='ctc',
                batch_size=batch_size)

            # Define model graph
            num_classes_main = 27
            num_classes_sub = 61
            model = MultitaskCTC(
                encoder_type=encoder_type,
                input_size=inputs[0].shape[1],
                num_units=256,
                num_layers_main=2,
                num_layers_sub=1,
                num_classes_main=num_classes_main,
                num_classes_sub=num_classes_sub,
                main_task_weight=0.8,
                lstm_impl=lstm_impl,
                parameter_init=0.1,
                clip_grad_norm=5.0,
                clip_activation=50,
                num_proj=256,
                weight_decay=1e-8,
                # bottleneck_dim=50,
                bottleneck_dim=None,
                time_major=time_major)

            # Define placeholders
            model.create_placeholders()
            learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

            # Add to the graph each operation
            loss_op, logits_main, logits_sub = model.compute_loss(
                model.inputs_pl_list[0],
                model.labels_pl_list[0],
                model.labels_sub_pl_list[0],
                model.inputs_seq_len_pl_list[0],
                model.keep_prob_pl_list[0])
            train_op = model.train(
                loss_op,
                optimizer='adam',
                learning_rate=learning_rate_pl)
            decode_op_main, decode_op_sub = model.decoder(
                logits_main,
                logits_sub,
                model.inputs_seq_len_pl_list[0],
                beam_width=20)
            ler_op_main, ler_op_sub = model.compute_ler(
                decode_op_main, decode_op_sub,
                model.labels_pl_list[0], model.labels_sub_pl_list[0])

            # Define learning rate controller
            learning_rate = 1e-3
            lr_controller = Controller(learning_rate_init=learning_rate,
                                       decay_start_epoch=20,
                                       decay_rate=0.9,
                                       decay_patient_epoch=5,
                                       lower_better=True)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()),
                   "{:,}".format(total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                model.inputs_pl_list[0]: inputs,
                model.labels_pl_list[0]: list2sparsetensor(labels_char, padded_value=-1),
                model.labels_sub_pl_list[0]: list2sparsetensor(labels_phone, padded_value=-1),
                model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                model.keep_prob_pl_list[0]: 0.9,
                learning_rate_pl: learning_rate
            }

            idx2phone = Idx2phone(map_file_path='./phone61.txt')

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 1000
                start_time_step = time.time()
                for step in range(max_steps):

                    # Compute loss
                    _, loss_train = sess.run(
                        [train_op, loss_op], feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(model.clipped_grads,
                    #                  feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Change to evaluation mode
                        feed_dict[model.keep_prob_pl_list[0]] = 1.0

                        # Compute accuracy
                        ler_train_char, ler_train_phone = sess.run(
                            [ler_op_main, ler_op_sub], feed_dict=feed_dict)

                        duration_step = time.time() - start_time_step
                        print('Step %d: loss = %.3f / cer = %.3f / per = %.3f (%.3f sec) / lr = %.5f' %
                              (step + 1, loss_train, ler_train_char,
                               ler_train_phone, duration_step, learning_rate))
                        start_time_step = time.time()

                        # Visualize
                        labels_pred_char_st, labels_pred_phone_st = sess.run(
                            [decode_op_main, decode_op_sub],
                            feed_dict=feed_dict)
                        labels_pred_char = sparsetensor2list(
                            labels_pred_char_st, batch_size=batch_size)
                        labels_pred_phone = sparsetensor2list(
                            labels_pred_phone_st, batch_size=batch_size)

                        print('Character')
                        try:
                            print('  Ref: %s' % idx2alpha(labels_char[0]))
                            print('  Hyp: %s' % idx2alpha(labels_pred_char[0]))
                        except IndexError:
                            print('  Hyp: %s' % '')
                            # NOTE: this occurs when there is no prediction

                        print('Phone')
                        try:
                            print('  Ref: %s' % idx2phone(labels_phone[0]))
                            print('  Hyp: %s' %
                                  idx2phone(labels_pred_phone[0]))
                        except IndexError:
                            print('  Hyp: %s' % '')
                            # NOTE: this occurs when there is no prediction
                        print('-' * 30)

                        if ler_train_char < 0.1:
                            print('Model has converged.')
                            break

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=step,
                            value=ler_train_char)
                        feed_dict[learning_rate_pl] = learning_rate
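
# The examples above and below move between padded Python label lists and the
# SparseTensor feeds that the CTC ops expect via list2sparsetensor /
# sparsetensor2list. Their implementations are not part of this listing; what
# follows is a minimal sketch under the assumption that labels arrive as a
# list of 1-D integer sequences and that `padded_value` marks the positions
# to drop. The `_sketch` names are hypothetical.
import numpy as np
import tensorflow as tf


def list2sparsetensor_sketch(labels, padded_value):
    """Convert a batch of padded label sequences to a tf.SparseTensorValue."""
    indices, values, max_len = [], [], 0
    for b, seq in enumerate(labels):
        for t, label in enumerate(seq):
            if label == padded_value:
                continue  # skip padding positions
            indices.append([b, t])
            values.append(label)
            max_len = max(max_len, t + 1)
    return tf.SparseTensorValue(
        np.asarray(indices, dtype=np.int64).reshape(-1, 2),
        np.asarray(values, dtype=np.int32),
        np.asarray([len(labels), max_len], dtype=np.int64))


def sparsetensor2list_sketch(labels_st, batch_size):
    """Recover per-utterance label sequences from a SparseTensorValue."""
    batch = [[] for _ in range(batch_size)]
    for (b, _), value in zip(labels_st.indices, labels_st.values):
        batch[int(b)].append(int(value))
    return batch
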
def do_train(model, params, gpu_indices):
    """Run training.
    Args:
        model: the model to train
        params (dict): A dictionary of parameters
        gpu_indices (list): GPU indices
    """
    if 'kanji' in params['label_type']:
        map_file_path = '../metrics/mapping_files/' + \
            params['label_type'] + '_' + params['train_data_size'] + '.txt'
    elif 'kana' in params['label_type']:
        map_file_path = '../metrics/mapping_files/' + \
            params['label_type'] + '.txt'

    # Load dataset
    train_data = Dataset(
        data_type='train', train_data_size=params['train_data_size'],
        label_type=params['label_type'], map_file_path=map_file_path,
        batch_size=params['batch_size'], max_epoch=params['num_epoch'],
        splice=params['splice'],
        num_stack=params['num_stack'], num_skip=params['num_skip'],
        sort_utt=True, sort_stop_epoch=params['sort_stop_epoch'],
        num_gpu=len(gpu_indices))
    dev_data = Dataset(
        data_type='dev', train_data_size=params['train_data_size'],
        label_type=params['label_type'], map_file_path=map_file_path,
        batch_size=params['batch_size'], splice=params['splice'],
        num_stack=params['num_stack'], num_skip=params['num_skip'],
        sort_utt=False, num_gpu=len(gpu_indices))

    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default(), tf.device('/cpu:0'):

        # Create a variable to track the global step
        global_step = tf.Variable(0, name='global_step', trainable=False)

        # Set optimizer
        learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')
        optimizer = model._set_optimizer(
            params['optimizer'], learning_rate_pl)

        # Calculate the gradients for each model tower
        total_grads_and_vars, total_losses = [], []
        decode_ops_infer, ler_ops = [], []
        all_devices = ['/gpu:%d' % i_gpu for i_gpu in range(len(gpu_indices))]
        # NOTE: /cpu:0 is prepared for evaluation
        with tf.variable_scope(tf.get_variable_scope()):
            for i_gpu in range(len(all_devices)):
                with tf.device(all_devices[i_gpu]):
                    with tf.name_scope('tower_gpu%d' % i_gpu) as scope:

                        # Define placeholders in each tower
                        model.create_placeholders()

                        # Calculate the total loss for the current tower of the
                        # model. This function constructs the entire model but
                        # shares the variables across all towers.
                        tower_loss, tower_logits, tower_decoder_outputs_train, tower_decoder_outputs_infer = model.compute_loss(
                            model.inputs_pl_list[i_gpu],
                            model.labels_pl_list[i_gpu],
                            model.inputs_seq_len_pl_list[i_gpu],
                            model.labels_seq_len_pl_list[i_gpu],
                            model.keep_prob_encoder_pl_list[i_gpu],
                            model.keep_prob_decoder_pl_list[i_gpu],
                            model.keep_prob_embedding_pl_list[i_gpu],
                            scope)
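                        # Expand the scalar tower loss to shape (1,) so the
                        # per-tower losses can be concatenated and averaged
                        # below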
                        tower_loss = tf.expand_dims(tower_loss, axis=0)
                        total_losses.append(tower_loss)

                        # Reuse variables for the next tower
                        tf.get_variable_scope().reuse_variables()

                        # Calculate the gradients for the batch of data on this
                        # tower
                        tower_grads_and_vars = optimizer.compute_gradients(
                            tower_loss)

                        # Gradient clipping
                        tower_grads_and_vars = model._clip_gradients(
                            tower_grads_and_vars)

                        # TODO: Optionally add gradient noise

                        # Keep track of the gradients across all towers
                        total_grads_and_vars.append(tower_grads_and_vars)

                        # Add to the graph each operation per tower
                        _, decode_op_tower_infer = model.decode(
                            tower_decoder_outputs_train,
                            tower_decoder_outputs_infer)
                        decode_ops_infer.append(decode_op_tower_infer)
                        # ler_op_tower = model.compute_ler(
                        #     decode_op_tower, model.labels_pl_list[i_gpu])
                        ler_op_tower = model.compute_ler(
                            model.labels_st_true_pl_list[i_gpu],
                            model.labels_st_pred_pl_list[i_gpu])
                        ler_op_tower = tf.expand_dims(ler_op_tower, axis=0)
                        ler_ops.append(ler_op_tower)

        # Aggregate losses, then calculate average loss
        total_losses = tf.concat(axis=0, values=total_losses)
        loss_op = tf.reduce_mean(total_losses, axis=0)
        ler_ops = tf.concat(axis=0, values=ler_ops)
        ler_op = tf.reduce_mean(ler_ops, axis=0)

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers (a sketch of
        # average_gradients follows this function)
        average_grads_and_vars = average_gradients(total_grads_and_vars)

        # Apply the gradients to adjust the shared variables.
        train_op = optimizer.apply_gradients(average_grads_and_vars,
                                             global_step=global_step)

        # Define learning rate controller
        lr_controller = Controller(
            learning_rate_init=params['learning_rate'],
            decay_start_epoch=params['decay_start_epoch'],
            decay_rate=params['decay_rate'],
            decay_patient_epoch=params['decay_patient_epoch'],
            lower_better=True)

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(model.summaries_train)
        summary_dev = tf.summary.merge(model.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" %
                  (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()),
               "{:,}".format(total_parameters / 1000000)))

        csv_steps, csv_loss_train, csv_loss_dev = [], [], []
        csv_ler_train, csv_ler_dev = [], []
        # Create a session for running operation on the graph
        # NOTE: Start running operations on the Graph. allow_soft_placement
        # must be set to True to build towers on GPU, as some of the ops do not
        # have GPU implementations.
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                              log_device_placement=False)) as sess:

            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(
                model.save_path, sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Train model
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            cer_dev_best = 1
            not_improved_epoch = 0
            learning_rate = float(params['learning_rate'])
            for step, (data, is_new_epoch) in enumerate(train_data):

                # Create feed dictionary for next mini batch (train)
                inputs, labels_train, inputs_seq_len, labels_seq_len, _ = data
                feed_dict_train = {}
                for i_gpu in range(len(gpu_indices)):
                    feed_dict_train[model.inputs_pl_list[i_gpu]
                                    ] = inputs[i_gpu]
                    feed_dict_train[model.labels_pl_list[i_gpu]
                                    ] = labels_train[i_gpu]
                    feed_dict_train[model.inputs_seq_len_pl_list[i_gpu]
                                    ] = inputs_seq_len[i_gpu]
                    feed_dict_train[model.labels_seq_len_pl_list[i_gpu]
                                    ] = labels_seq_len[i_gpu]
                    feed_dict_train[model.keep_prob_encoder_pl_list[i_gpu]
                                    ] = 1 - float(params['dropout_encoder'])
                    feed_dict_train[model.keep_prob_decoder_pl_list[i_gpu]
                                    ] = 1 - float(params['dropout_decoder'])
                    feed_dict_train[model.keep_prob_embedding_pl_list[i_gpu]
                                    ] = 1 - float(params['dropout_embedding'])
                feed_dict_train[learning_rate_pl] = learning_rate

                # Update parameters
                sess.run(train_op, feed_dict=feed_dict_train)

                if (step + 1) % int(params['print_step'] / len(gpu_indices)) == 0:

                    # Create feed dictionary for next mini batch (dev)
                    inputs, labels_dev, inputs_seq_len, labels_seq_len, _ = dev_data.next()[
                        0]
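                    # NOTE: dev_data.next() yields a (data, is_new_epoch)
                    # pair; [0] keeps only the batch data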
                    feed_dict_dev = {}
                    for i_gpu in range(len(gpu_indices)):
                        feed_dict_dev[model.inputs_pl_list[i_gpu]
                                      ] = inputs[i_gpu]
                        feed_dict_dev[model.labels_pl_list[i_gpu]
                                      ] = labels_dev[i_gpu]
                        feed_dict_dev[model.inputs_seq_len_pl_list[i_gpu]
                                      ] = inputs_seq_len[i_gpu]
                        feed_dict_dev[model.labels_seq_len_pl_list[i_gpu]
                                      ] = labels_seq_len[i_gpu]
                        feed_dict_dev[model.keep_prob_encoder_pl_list[i_gpu]
                                      ] = 1.0
                        feed_dict_dev[model.keep_prob_decoder_pl_list[i_gpu]
                                      ] = 1.0
                        feed_dict_dev[model.keep_prob_embedding_pl_list[i_gpu]
                                      ] = 1.0

                    # Compute loss
                    loss_train = sess.run(
                        loss_op, feed_dict=feed_dict_train)
                    loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                    csv_steps.append(step)
                    csv_loss_train.append(loss_train)
                    csv_loss_dev.append(loss_dev)

                    # Change to evaluation mode
                    for i_gpu in range(len(gpu_indices)):
                        feed_dict_train[model.keep_prob_encoder_pl_list[i_gpu]] = 1.0
                        feed_dict_train[model.keep_prob_decoder_pl_list[i_gpu]] = 1.0
                        feed_dict_train[model.keep_prob_embedding_pl_list[i_gpu]] = 1.0

                    # Predict class ids
                    predicted_ids_train_list, summary_str_train = sess.run(
                        [decode_ops_infer, summary_train], feed_dict=feed_dict_train)
                    predicted_ids_dev_list, summary_str_dev = sess.run(
                        [decode_ops_infer, summary_dev], feed_dict=feed_dict_dev)

                    # Convert to sparsetensor to compute LER
                    feed_dict_ler_train = {}
                    for i_gpu in range(len(gpu_indices)):
                        feed_dict_ler_train[model.labels_st_true_pl_list[i_gpu]] = list2sparsetensor(
                            labels_train[i_gpu],
                            padded_value=train_data.padded_value)
                        feed_dict_ler_train[model.labels_st_pred_pl_list[i_gpu]] = list2sparsetensor(
                            predicted_ids_train_list[i_gpu],
                            padded_value=train_data.padded_value)
                    feed_dict_ler_dev = {}
                    for i_gpu in range(len(gpu_indices)):
                        feed_dict_ler_dev[model.labels_st_true_pl_list[i_gpu]] = list2sparsetensor(
                            labels_dev[i_gpu],
                            padded_value=dev_data.padded_value)
                        feed_dict_ler_dev[model.labels_st_pred_pl_list[i_gpu]] = list2sparsetensor(
                            predicted_ids_dev_list[i_gpu],
                            padded_value=dev_data.padded_value)

                    # Compute accuracy
                    # ler_train = sess.run(ler_op, feed_dict=feed_dict_ler_train)
                    # ler_dev = sess.run(ler_op, feed_dict=feed_dict_ler_dev)
                    ler_train = 1
                    ler_dev = 1
                    csv_ler_train.append(ler_train)
                    csv_ler_dev.append(ler_dev)
                    # TODO: fix this

                    # Update event files
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    duration_step = time.time() - start_time_step
                    print("Step %d (epoch: %.3f): loss = %.3f (%.3f) / ler = %.3f (%.3f) / lr = %.5f (%.3f min)" %
                          (step + 1, train_data.epoch_detail, loss_train, loss_dev, ler_train, ler_dev,
                           learning_rate, duration_step / 60))
                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if is_new_epoch:
                    duration_epoch = time.time() - start_time_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (train_data.epoch, duration_epoch / 60))

                    # Save figure of loss & LER
                    plot_loss(csv_loss_train, csv_loss_dev, csv_steps,
                              save_path=model.save_path)
                    plot_ler(csv_ler_train, csv_ler_dev, csv_steps,
                             label_type=params['label_type'],
                             save_path=model.save_path)

                    if train_data.epoch >= params['eval_start_epoch']:
                        start_time_eval = time.time()
                        print('=== Dev Data Evaluation ===')
                        cer_dev_epoch = do_eval_cer(
                            session=sess,
                            decode_ops=decode_ops_infer,
                            model=model,
                            dataset=dev_data,
                            label_type=params['label_type'],
                            train_data_size=params['train_data_size'],
                            eval_batch_size=1)
                        print('  CER: %f %%' % (cer_dev_epoch * 100))

                        if cer_dev_epoch < cer_dev_best:
                            cer_dev_best = cer_dev_epoch
                            print('■■■ ↑Best Score (CER)↑ ■■■')

                            # Save model (check point)
                            checkpoint_file = join(
                                model.save_path, 'model.ckpt')
                            save_path = saver.save(
                                sess, checkpoint_file, global_step=train_data.epoch)
                            print("Model saved in file: %s" % save_path)
                        else:
                            not_improved_epoch += 1

                        duration_eval = time.time() - start_time_eval
                        print('Evaluation time: %.3f min' %
                              (duration_eval / 60))

                        # Early stopping
                        if not_improved_epoch == params['not_improved_patient_epoch']:
                            break

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=train_data.epoch,
                            value=cer_dev_epoch)

                    start_time_epoch = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Mark that training finished correctly
            with open(join(model.save_path, 'complete.txt'), 'w') as f:
                f.write('')
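
# average_gradients is the synchronization point between towers in the
# multi-GPU examples here. Its body is not included in this listing; below is
# a minimal sketch in the style of TensorFlow's multi-GPU CIFAR-10 tutorial,
# assuming every tower emits its (gradient, variable) pairs in the same
# variable order. The `_sketch` name is hypothetical.
import tensorflow as tf


def average_gradients_sketch(tower_grads_and_vars):
    """Average gradients over towers; the variables are shared, so only the
    gradients differ."""
    average_grads_and_vars = []
    for grads_and_vars in zip(*tower_grads_and_vars):
        # grads_and_vars == ((grad_gpu0, var), (grad_gpu1, var), ...)
        grads = [tf.expand_dims(grad, 0)
                 for grad, _ in grads_and_vars if grad is not None]
        mean_grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        # The variable is shared across towers, so the first tower's
        # reference is as good as any
        average_grads_and_vars.append((mean_grad, grads_and_vars[0][1]))
    return average_grads_and_vars
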
def do_train(model, params, gpu_indices):
    """Run training.
    Args:
        model: the model to train
        params (dict): A dictionary of parameters
        gpu_indices (list): GPU indices
    """
    if 'kanji' in params['label_type']:
        map_file_path = '../metrics/mapping_files/' + \
            params['label_type'] + '_' + params['train_data_size'] + '.txt'
    elif 'kana' in params['label_type']:
        map_file_path = '../metrics/mapping_files/' + \
            params['label_type'] + '.txt'

    # Load dataset
    train_data = Dataset(data_type='train',
                         train_data_size=params['train_data_size'],
                         label_type=params['label_type'],
                         map_file_path=map_file_path,
                         batch_size=params['batch_size'],
                         max_epoch=params['num_epoch'],
                         splice=params['splice'],
                         num_stack=params['num_stack'],
                         num_skip=params['num_skip'],
                         sort_utt=True,
                         sort_stop_epoch=params['sort_stop_epoch'],
                         num_gpu=len(gpu_indices))
    dev_data = Dataset(data_type='dev',
                       train_data_size=params['train_data_size'],
                       label_type=params['label_type'],
                       map_file_path=map_file_path,
                       batch_size=params['batch_size'],
                       splice=params['splice'],
                       num_stack=params['num_stack'],
                       num_skip=params['num_skip'],
                       sort_utt=False,
                       num_gpu=len(gpu_indices))

    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default(), tf.device('/cpu:0'):

        # Create a variable to track the global step
        global_step = tf.Variable(0, name='global_step', trainable=False)

        # Set optimizer
        learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')
        optimizer = model._set_optimizer(params['optimizer'], learning_rate_pl)

        # Calculate the gradients for each model tower
        total_grads_and_vars, total_losses = [], []
        decode_ops_infer, ler_ops = [], []
        all_devices = ['/gpu:%d' % i_gpu for i_gpu in range(len(gpu_indices))]
        # NOTE: /cpu:0 is prepared for evaluation
        with tf.variable_scope(tf.get_variable_scope()):
            for i_gpu in range(len(all_devices)):
                with tf.device(all_devices[i_gpu]):
                    with tf.name_scope('tower_gpu%d' % i_gpu) as scope:

                        # Define placeholders in each tower
                        model.create_placeholders()

                        # Calculate the total loss for the current tower of the
                        # model. This function constructs the entire model but
                        # shares the variables across all towers.
                        tower_loss, tower_logits, tower_decoder_outputs_train, tower_decoder_outputs_infer = model.compute_loss(
                            model.inputs_pl_list[i_gpu],
                            model.labels_pl_list[i_gpu],
                            model.inputs_seq_len_pl_list[i_gpu],
                            model.labels_seq_len_pl_list[i_gpu],
                            model.keep_prob_encoder_pl_list[i_gpu],
                            model.keep_prob_decoder_pl_list[i_gpu],
                            model.keep_prob_embedding_pl_list[i_gpu], scope)
                        tower_loss = tf.expand_dims(tower_loss, axis=0)
                        total_losses.append(tower_loss)

                        # Reuse variables for the next tower
                        tf.get_variable_scope().reuse_variables()

                        # Calculate the gradients for the batch of data on this
                        # tower
                        tower_grads_and_vars = optimizer.compute_gradients(
                            tower_loss)

                        # Gradient clipping
                        tower_grads_and_vars = model._clip_gradients(
                            tower_grads_and_vars)

                        # TODO: Optionally add gradient noise

                        # Keep track of the gradients across all towers
                        total_grads_and_vars.append(tower_grads_and_vars)

                        # Add to the graph each operation per tower
                        _, decode_op_tower_infer = model.decode(
                            tower_decoder_outputs_train,
                            tower_decoder_outputs_infer)
                        decode_ops_infer.append(decode_op_tower_infer)
                        # ler_op_tower = model.compute_ler(
                        #     decode_op_tower, model.labels_pl_list[i_gpu])
                        ler_op_tower = model.compute_ler(
                            model.labels_st_true_pl_list[i_gpu],
                            model.labels_st_pred_pl_list[i_gpu])
                        ler_op_tower = tf.expand_dims(ler_op_tower, axis=0)
                        ler_ops.append(ler_op_tower)

        # Aggregate losses, then calculate average loss
        total_losses = tf.concat(axis=0, values=total_losses)
        loss_op = tf.reduce_mean(total_losses, axis=0)
        ler_ops = tf.concat(axis=0, values=ler_ops)
        ler_op = tf.reduce_mean(ler_ops, axis=0)

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers
        average_grads_and_vars = average_gradients(total_grads_and_vars)

        # Apply the gradients to adjust the shared variables.
        train_op = optimizer.apply_gradients(average_grads_and_vars,
                                             global_step=global_step)

        # Define learning rate controller (a sketch of Controller.decay_lr
        # follows this function)
        lr_controller = Controller(
            learning_rate_init=params['learning_rate'],
            decay_start_epoch=params['decay_start_epoch'],
            decay_rate=params['decay_rate'],
            decay_patient_epoch=params['decay_patient_epoch'],
            lower_better=True)

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(model.summaries_train)
        summary_dev = tf.summary.merge(model.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()), "{:,}".format(
                  total_parameters / 1000000)))

        csv_steps, csv_loss_train, csv_loss_dev = [], [], []
        csv_ler_train, csv_ler_dev = [], []
        # Create a session for running operation on the graph
        # NOTE: Start running operations on the Graph. allow_soft_placement
        # must be set to True to build towers on GPU, as some of the ops do not
        # have GPU implementations.
        with tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False)) as sess:

            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(model.save_path, sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Train model
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            cer_dev_best = 1
            not_improved_epoch = 0
            learning_rate = float(params['learning_rate'])
            for step, (data, is_new_epoch) in enumerate(train_data):

                # Create feed dictionary for next mini batch (train)
                inputs, labels_train, inputs_seq_len, labels_seq_len, _ = data
                feed_dict_train = {}
                for i_gpu in range(len(gpu_indices)):
                    feed_dict_train[
                        model.inputs_pl_list[i_gpu]] = inputs[i_gpu]
                    feed_dict_train[
                        model.labels_pl_list[i_gpu]] = labels_train[i_gpu]
                    feed_dict_train[model.inputs_seq_len_pl_list[
                        i_gpu]] = inputs_seq_len[i_gpu]
                    feed_dict_train[model.labels_seq_len_pl_list[
                        i_gpu]] = labels_seq_len[i_gpu]
                    feed_dict_train[
                        model.keep_prob_encoder_pl_list[i_gpu]] = 1 - float(
                            params['dropout_encoder'])
                    feed_dict_train[
                        model.keep_prob_decoder_pl_list[i_gpu]] = 1 - float(
                            params['dropout_decoder'])
                    feed_dict_train[
                        model.keep_prob_embedding_pl_list[i_gpu]] = 1 - float(
                            params['dropout_embedding'])
                feed_dict_train[learning_rate_pl] = learning_rate

                # Update parameters
                sess.run(train_op, feed_dict=feed_dict_train)

                if (step + 1) % int(
                        params['print_step'] / len(gpu_indices)) == 0:

                    # Create feed dictionary for next mini batch (dev)
                    inputs, labels_dev, inputs_seq_len, labels_seq_len, _ = dev_data.next(
                    )[0]
                    feed_dict_dev = {}
                    for i_gpu in range(len(gpu_indices)):
                        feed_dict_dev[
                            model.inputs_pl_list[i_gpu]] = inputs[i_gpu]
                        feed_dict_dev[
                            model.labels_pl_list[i_gpu]] = labels_dev[i_gpu]
                        feed_dict_dev[model.inputs_seq_len_pl_list[
                            i_gpu]] = inputs_seq_len[i_gpu]
                        feed_dict_dev[model.labels_seq_len_pl_list[
                            i_gpu]] = labels_seq_len[i_gpu]
                        feed_dict_dev[
                            model.keep_prob_encoder_pl_list[i_gpu]] = 1.0
                        feed_dict_dev[
                            model.keep_prob_decoder_pl_list[i_gpu]] = 1.0
                        feed_dict_dev[
                            model.keep_prob_embedding_pl_list[i_gpu]] = 1.0

                    # Compute loss
                    loss_train = sess.run(loss_op, feed_dict=feed_dict_train)
                    loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                    csv_steps.append(step)
                    csv_loss_train.append(loss_train)
                    csv_loss_dev.append(loss_dev)

                    # Change to evaluation mode
                    for i_gpu in range(len(gpu_indices)):
                        feed_dict_train[
                            model.keep_prob_encoder_pl_list[i_gpu]] = 1.0
                        feed_dict_train[
                            model.keep_prob_decoder_pl_list[i_gpu]] = 1.0
                        feed_dict_train[
                            model.keep_prob_embedding_pl_list[i_gpu]] = 1.0

                    # Predict class ids
                    predicted_ids_train_list, summary_str_train = sess.run(
                        [decode_ops_infer, summary_train],
                        feed_dict=feed_dict_train)
                    predicted_ids_dev_list, summary_str_dev = sess.run(
                        [decode_ops_infer, summary_dev],
                        feed_dict=feed_dict_dev)

                    # Convert to sparsetensor to compute LER
                    feed_dict_ler_train = {}
                    for i_gpu in range(len(gpu_indices)):
                        feed_dict_ler_train[model.labels_st_true_pl_list[
                            i_gpu]] = list2sparsetensor(
                                labels_train[i_gpu],
                                padded_value=train_data.padded_value)
                        feed_dict_ler_train[model.labels_st_pred_pl_list[
                            i_gpu]] = list2sparsetensor(
                                predicted_ids_train_list[i_gpu],
                                padded_value=train_data.padded_value)
                    feed_dict_ler_dev = {}
                    for i_gpu in range(len(gpu_indices)):
                        feed_dict_ler_dev[model.labels_st_true_pl_list[
                            i_gpu]] = list2sparsetensor(
                                labels_dev[i_gpu],
                                padded_value=dev_data.padded_value)
                        feed_dict_ler_dev[model.labels_st_pred_pl_list[
                            i_gpu]] = list2sparsetensor(
                                predicted_ids_dev_list[i_gpu],
                                padded_value=dev_data.padded_value)

                    # Compute accuracy
                    # ler_train = sess.run(ler_op, feed_dict=feed_dict_ler_train)
                    # ler_dev = sess.run(ler_op, feed_dict=feed_dict_ler_dev)
                    ler_train = 1
                    ler_dev = 1
                    csv_ler_train.append(ler_train)
                    csv_ler_dev.append(ler_dev)
                    # TODO: fix this

                    # Update event files
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    duration_step = time.time() - start_time_step
                    print(
                        "Step %d (epoch: %.3f): loss = %.3f (%.3f) / ler = %.3f (%.3f) / lr = %.5f (%.3f min)"
                        % (step + 1, train_data.epoch_detail, loss_train,
                           loss_dev, ler_train, ler_dev, learning_rate,
                           duration_step / 60))
                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if is_new_epoch:
                    duration_epoch = time.time() - start_time_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (train_data.epoch, duration_epoch / 60))

                    # Save figure of loss & LER
                    plot_loss(csv_loss_train,
                              csv_loss_dev,
                              csv_steps,
                              save_path=model.save_path)
                    plot_ler(csv_ler_train,
                             csv_ler_dev,
                             csv_steps,
                             label_type=params['label_type'],
                             save_path=model.save_path)

                    if train_data.epoch >= params['eval_start_epoch']:
                        start_time_eval = time.time()
                        print('=== Dev Data Evaluation ===')
                        cer_dev_epoch = do_eval_cer(
                            session=sess,
                            decode_ops=decode_ops_infer,
                            model=model,
                            dataset=dev_data,
                            label_type=params['label_type'],
                            train_data_size=params['train_data_size'],
                            eval_batch_size=1)
                        print('  CER: %f %%' % (cer_dev_epoch * 100))

                        if cer_dev_epoch < cer_dev_best:
                            cer_dev_best = cer_dev_epoch
                            print('■■■ ↑Best Score (CER)↑ ■■■')

                            # Save model (check point)
                            checkpoint_file = join(model.save_path,
                                                   'model.ckpt')
                            save_path = saver.save(
                                sess,
                                checkpoint_file,
                                global_step=train_data.epoch)
                            print("Model saved in file: %s" % save_path)
                        else:
                            not_improved_epoch += 1

                        duration_eval = time.time() - start_time_eval
                        print('Evaluation time: %.3f min' %
                              (duration_eval / 60))

                        # Early stopping
                        if not_improved_epoch == params[
                                'not_improved_patient_epoch']:
                            break

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=train_data.epoch,
                            value=cer_dev_epoch)

                    start_time_epoch = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Mark that training finished correctly
            with open(join(model.save_path, 'complete.txt'), 'w') as f:
                f.write('')
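
# The learning-rate Controller used throughout is configured with a decay
# start epoch, decay rate, patience, and a lower_better flag, but its body is
# not shown. Below is a minimal sketch of the decay_lr behaviour those
# arguments imply: exponential decay that starts at decay_start_epoch and
# fires once the monitored value has not improved for decay_patient_epoch
# consecutive calls. The `ControllerSketch` name is hypothetical.


class ControllerSketch(object):

    def __init__(self, learning_rate_init, decay_start_epoch, decay_rate,
                 decay_patient_epoch=1, lower_better=True):
        self.learning_rate_init = learning_rate_init
        self.decay_start_epoch = decay_start_epoch
        self.decay_rate = decay_rate
        self.decay_patient_epoch = decay_patient_epoch
        self.lower_better = lower_better
        self.best_value = float('inf')  # tracked in "lower is better" space
        self.not_improved_count = 0

    def decay_lr(self, learning_rate, epoch, value):
        if not self.lower_better:
            value = -value  # flip the sign so smaller is always better
        if epoch < self.decay_start_epoch:
            return learning_rate
        if value < self.best_value:
            self.best_value = value
            self.not_improved_count = 0
        else:
            self.not_improved_count += 1
            if self.not_improved_count >= self.decay_patient_epoch:
                learning_rate *= self.decay_rate
                self.not_improved_count = 0
        return learning_rate
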
def do_train(network, optimizer, learning_rate, batch_size, epoch_num,
             label_type, num_stack, num_skip, gpu_indices):
    """Run training. If target labels are phone, the model is evaluated by PER
    with 39 phones.
    Args:
        network: network to train
        optimizer: string, the name of optimizer.
            ex.) adam, rmsprop
        learning_rate: A float value, the initial learning rate
        batch_size: int, the size of each mini-batch
        epoch_num: int, the number of epochs to train
        label_type: string, phone39 or phone48 or phone61 or character
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
        gpu_indices: list of integer
    """
    # Load dataset
    train_data = DataSet(data_type='train',
                         label_type=label_type,
                         batch_size=batch_size,
                         num_stack=num_stack,
                         num_skip=num_skip,
                         is_sorted=True,
                         num_gpu=len(gpu_indices))
    dev_data = DataSet(data_type='dev',
                       label_type=label_type,
                       batch_size=batch_size,
                       num_stack=num_stack,
                       num_skip=num_skip,
                       is_sorted=False,
                       num_gpu=len(gpu_indices))
    if label_type == 'character':
        # TODO: what should be done at evaluation time?
        test_data = DataSet(data_type='test',
                            label_type='character',
                            batch_size=batch_size,
                            num_stack=num_stack,
                            num_skip=num_skip,
                            is_sorted=False,
                            num_gpu=1)
    else:
        test_data = DataSet(data_type='test',
                            label_type='phone39',
                            batch_size=batch_size,
                            num_stack=num_stack,
                            num_skip=num_skip,
                            is_sorted=False,
                            num_gpu=1)

    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default(), tf.device('/cpu:0'):

        # Create a variable to track the global step
        global_step = tf.Variable(0, name='global_step', trainable=False)
        # NOTE: the `optimizer` argument is overridden here; RMSProp is
        # hard-coded for this multi-GPU example
        optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate)

        # Calculate the gradients for each model tower
        tower_grads = []
        network.inputs = []
        network.labels = []
        network.inputs_seq_len = []
        network.keep_prob_input = []
        network.keep_prob_hidden = []
        # TODO: also prepare a tower for the CPU

        all_devices = ['/gpu:%d' % i_gpu for i_gpu in range(len(gpu_indices))]
        # NOTE: /cpu:0 is prepared for evaluation

        loss_dict = {}
        total_loss = []
        with tf.variable_scope(tf.get_variable_scope()):
            for i_device in range(len(all_devices)):
                with tf.device(all_devices[i_device]):
                    with tf.name_scope('%s_%d' % ('tower', i_device)) as scope:

                        # Define placeholders in each tower
                        network.inputs.append(
                            tf.placeholder(
                                tf.float32,
                                shape=[None, None, network.input_size],
                                name='input' + str(i_device)))
                        indices_pl = tf.placeholder(tf.int64,
                                                    name='indices%d' %
                                                    i_device)
                        values_pl = tf.placeholder(tf.int32,
                                                   name='values%d' % i_device)
                        shape_pl = tf.placeholder(tf.int64,
                                                  name='shape%d' % i_device)
                        network.labels.append(
                            tf.SparseTensor(indices_pl, values_pl, shape_pl))
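                        # NOTE: feeding a tf.SparseTensorValue to this
                        # composite SparseTensor fills the indices/values/
                        # shape placeholders in one go (see feed_dict_train
                        # below)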
                        network.inputs_seq_len.append(
                            tf.placeholder(tf.int64,
                                           shape=[None],
                                           name='inputs_seq_len%d' % i_device))
                        network.keep_prob_input.append(
                            tf.placeholder(tf.float32,
                                           name='keep_prob_input%d' %
                                           i_device))
                        network.keep_prob_hidden.append(
                            tf.placeholder(tf.float32,
                                           name='keep_prob_hidden%d' %
                                           i_device))

                        # Calculate the loss for one tower of the model. This
                        # function constructs the entire model but shares the
                        # variables across all towers
                        loss, logits = network.compute_loss(
                            network.inputs[i_device], network.labels[i_device],
                            network.inputs_seq_len[i_device],
                            network.keep_prob_input[i_device],
                            network.keep_prob_hidden[i_device])

                        # Assemble all of the losses for the current tower
                        # only
                        losses = tf.get_collection('losses', scope)

                        # Calculate the total loss for the current tower
                        tower_loss = tf.add_n(losses, name='tower_loss')
                        total_loss.append(tower_loss)

                        # Reuse variables for the next tower
                        tf.get_variable_scope().reuse_variables()

                        # Retain the summaries from the final tower.
                        # summaries = tf.get_collection(
                        #     tf.GraphKeys.SUMMARIES, scope)

                        # Calculate the gradients for the batch of data on this
                        # tower.
                        grads = optimizer.compute_gradients(tower_loss)

                        # TODO: gradient clipping

                        # Keep track of the gradients across all towers.
                        tower_grads.append(grads)

        # Aggregate losses, then calculate average loss.
        loss_op = tf.add_n(total_loss) / len(gpu_indices)

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers (see the average_gradients
        # sketch above)
        grads = average_gradients(tower_grads)

        # Add a summary to track the learning rate.
        # summaries.append(tf.summary.scalar('learning_rate', lr))

        # Add histograms for gradients.
        # for grad, var in grads:
        #   if grad is not None:
        # summaries.append(tf.summary.histogram(var.op.name + '/gradients',
        # grad))

        # Apply the gradients to adjust the shared variables.
        train_op = optimizer.apply_gradients(grads, global_step=global_step)

        # Add histograms for trainable variables.
        # for var in tf.trainable_variables():
        #   summaries.append(tf.summary.histogram(var.op.name, var))

        # Track the moving averages of all trainable variables.
        # variable_averages = tf.train.ExponentialMovingAverage(
        #     0.9999, global_step)
        # variables_averages_op =
        # variable_averages.apply(tf.trainable_variables())

        # Group all updates to into a single train op.
        # train_op = tf.group(apply_gradient_op, variables_averages_op)

        ####################################

        # Add to the graph each operation (Use last placeholders)
        # train_op = network.train(loss_op,
        #                          optimizer='adam',
        #                          learning_rate_init=learning_rate,
        #                          is_scheduled=False)
        decode_op = network.decoder(logits,
                                    network.inputs_seq_len[-1],
                                    decode_type='beam_search',
                                    beam_width=20)
        ler_op = network.compute_ler(decode_op, network.labels[-1])

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(network.summaries_train)
        summary_dev = tf.summary.merge(network.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()), "{:,}".format(
                  total_parameters / 1000000)))

        csv_steps, csv_loss_train, csv_loss_dev = [], [], []
        csv_ler_train, csv_ler_dev = [], []
        # Create a session for running operation on the graph
        # NOTE: Start running operations on the Graph. allow_soft_placement
        # must be set to True to build towers on GPU, as some of the ops do not
        # have GPU implementations.
        with tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False)) as sess:

            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(network.model_dir,
                                                   sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Make generator
            mini_batch_train = train_data.next_batch(session=sess)
            mini_batch_dev = dev_data.next_batch(session=sess)

            # Train model
            global_batch_size = batch_size * len(gpu_indices)
            iter_per_epoch = int(train_data.data_num / global_batch_size)
            if train_data.data_num % global_batch_size != 0:
                iter_per_epoch += 1
            max_steps = iter_per_epoch * epoch_num
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            error_best = 1
            for step in range(max_steps):

                # Create feed dictionary for next mini batch (train)
                inputs, labels_st, inputs_seq_len, _ = mini_batch_train.__next__(
                )
                feed_dict_train, feed_dict_dev = {}, {}
                for i_gpu in range(len(gpu_indices)):
                    feed_dict_train[network.inputs[i_gpu]] = inputs[i_gpu]
                    feed_dict_train[network.labels[i_gpu]] = labels_st[i_gpu]
                    feed_dict_train[
                        network.inputs_seq_len[i_gpu]] = inputs_seq_len[i_gpu]
                    feed_dict_train[network.keep_prob_input[
                        i_gpu]] = network.dropout_ratio_input
                    feed_dict_train[network.keep_prob_hidden[
                        i_gpu]] = network.dropout_ratio_hidden

                # Create feed dictionary for next mini batch (dev)
                inputs, labels_st, inputs_seq_len, _ = mini_batch_dev.__next__(
                )
                for i_gpu in range(len(gpu_indices)):
                    feed_dict_dev[network.inputs[i_gpu]] = inputs[i_gpu]
                    feed_dict_dev[network.labels[i_gpu]] = labels_st[i_gpu]
                    feed_dict_dev[
                        network.inputs_seq_len[i_gpu]] = inputs_seq_len[i_gpu]
                    # NOTE: no dropout when evaluating on dev data
                    feed_dict_dev[network.keep_prob_input[i_gpu]] = 1.0
                    feed_dict_dev[network.keep_prob_hidden[i_gpu]] = 1.0

                # Update parameters
                sess.run(train_op, feed_dict=feed_dict_train)

                if (step + 1) % int(10 / len(gpu_indices)) == 0:

                    # Compute loss
                    loss_train = sess.run(loss_op, feed_dict=feed_dict_train)
                    # loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                    csv_steps.append(step)
                    csv_loss_train.append(loss_train)
                    # csv_loss_dev.append(loss_dev)

                    # Change to evaluation mode
                    for i_gpu in range(len(gpu_indices)):
                        feed_dict_train[network.keep_prob_input[i_gpu]] = 1.0
                        feed_dict_train[network.keep_prob_hidden[i_gpu]] = 1.0
                        feed_dict_dev[network.keep_prob_input[i_gpu]] = 1.0
                        feed_dict_dev[network.keep_prob_hidden[i_gpu]] = 1.0

                    # Compute accuracy & update event file
                    ler_train, summary_str_train = sess.run(
                        [ler_op, summary_train], feed_dict=feed_dict_train)
                    # ler_dev, summary_str_dev = sess.run(
                    #     [ler_op, summary_dev], feed_dict=feed_dict_dev)
                    csv_ler_train.append(ler_train)
                    # csv_ler_dev.append(ler_dev)
                    summary_writer.add_summary(summary_str_train, step + 1)
                    # summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    duration_step = time.time() - start_time_step
                    print(
                        "Step %d: loss = %.3f (%.3f) / ler = %.4f (%.4f) (%.3f min)"
                        % (step + 1, loss_train, 1, ler_train, 1,
                           duration_step / 60))
                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if (step + 1) % iter_per_epoch == 0 or (step + 1) == max_steps:
                    duration_epoch = time.time() - start_time_epoch
                    epoch = (step + 1) // iter_per_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (epoch, duration_epoch / 60))

                    # Save model (check point)
                    checkpoint_file = join(network.model_dir, 'model.ckpt')
                    # save_path = saver.save(
                    #     sess, checkpoint_file, global_step=epoch)
                    # print("Model saved in file: %s" % save_path)

                    if epoch >= 10:
                        start_time_eval = time.time()
                        # if label_type == 'character':
                        #     print('=== Dev Data Evaluation ===')
                        #     cer_dev_epoch = do_eval_cer(
                        #         session=sess,
                        #         decode_op=decode_op,
                        #         network=network,
                        #         dataset=dev_data)
                        #     print('  CER: %f %%' % (cer_dev_epoch * 100))
                        #
                        #     if cer_dev_epoch < error_best:
                        #         error_best = cer_dev_epoch
                        #         print('■■■ ↑Best Score (CER)↑ ■■■')
                        #
                        #         print('=== Test Data Evaluation ===')
                        #         cer_test = do_eval_cer(
                        #             session=sess,
                        #             decode_op=decode_op,
                        #             network=network,
                        #             dataset=test_data,
                        #             eval_batch_size=1)
                        #         print('  CER: %f %%' % (cer_test * 100))
                        #
                        # else:
                        #     print('=== Dev Data Evaluation ===')
                        #     per_dev_epoch = do_eval_per(
                        #         session=sess,
                        #         decode_op=decode_op,
                        #         per_op=ler_op,
                        #         network=network,
                        #         dataset=dev_data,
                        #         label_type=label_type)
                        #     print('  PER: %f %%' % (per_dev_epoch * 100))
                        #
                        #     if per_dev_epoch < error_best:
                        #         error_best = per_dev_epoch
                        #         print('■■■ ↑Best Score (PER)↑ ■■■')
                        #
                        #         print('=== Test Data Evaluation ===')
                        #         per_test = do_eval_per(
                        #             session=sess,
                        #             decode_op=decode_op,
                        #             per_op=ler_op,
                        #             network=network,
                        #             dataset=test_data,
                        #             label_type=label_type,
                        #             eval_batch_size=1)
                        #         print('  PER: %f %%' % (per_test * 100))

                        duration_eval = time.time() - start_time_eval
                        print('Evaluation time: %.3f min' %
                              (duration_eval / 60))

                start_time_epoch = time.time()
                start_time_step = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Save train & dev loss / LER histories (the dev lists stay empty
            # while the dev-side metrics above are commented out)
            save_loss(csv_steps,
                      csv_loss_train,
                      csv_loss_dev,
                      save_path=network.model_dir)
            save_ler(csv_steps,
                     csv_ler_train,
                     csv_ler_dev,
                     save_path=network.model_dir)

            # Training was finished correctly
            with open(join(network.model_dir, 'complete.txt'), 'w') as f:
                f.write('')


def do_train(model, params):
    """Run training of the multitask model.
    Args:
        model: the model to train
        params (dict): A dictionary of parameters
    """
    # Load dataset
    train_data = Dataset(data_type='train',
                         label_type_main=params['label_type_main'],
                         label_type_sub=params['label_type_sub'],
                         train_data_size=params['train_data_size'],
                         batch_size=params['batch_size'],
                         num_stack=params['num_stack'],
                         num_skip=params['num_skip'],
                         sort_utt=True)
    dev_data_step = Dataset(data_type='dev',
                            label_type_main=params['label_type_main'],
                            label_type_sub=params['label_type_sub'],
                            train_data_size=params['train_data_size'],
                            batch_size=params['batch_size'],
                            num_stack=params['num_stack'],
                            num_skip=params['num_skip'],
                            sort_utt=False)
    dev_data_epoch = Dataset(data_type='dev',
                             label_type_main=params['label_type_main'],
                             label_type_sub=params['label_type_sub'],
                             train_data_size=params['train_data_size'],
                             batch_size=params['batch_size'],
                             num_stack=params['num_stack'],
                             num_skip=params['num_skip'],
                             sort_utt=False)
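    # NB: the dev set is loaded twice: dev_data_step feeds the quick per-step
    # checks inside the training loop, while dev_data_epoch is swept by the
    # per-epoch CER evaluation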

    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default():

        # Define placeholders
        model.create_placeholders(gpu_index=0)

        # Add to the graph each operation
        loss_op, logits_main, logits_sub = model.compute_loss(
            model.inputs_pl_list[0],
            model.labels_pl_list[0],
            model.labels_sub_pl_list[0],
            model.inputs_seq_len_pl_list[0],
            model.keep_prob_input_pl_list[0],
            model.keep_prob_hidden_pl_list[0],
            model.keep_prob_output_pl_list[0])
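        # compute_loss returns the joint multitask loss plus the logits of the
        # main and sub output layers (decoded separately below)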
        train_op = model.train(loss_op,
                               optimizer=params['optimizer'],
                               learning_rate=model.learning_rate_pl_list[0])
        decode_op_main, decode_op_sub = model.decoder(
            logits_main,
            logits_sub,
            model.inputs_seq_len_pl_list[0],
            decode_type='beam_search',
            beam_width=20)
        ler_op_main, ler_op_sub = model.compute_ler(
            decode_op_main, decode_op_sub,
            model.labels_pl_list[0], model.labels_sub_pl_list[0])
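        # (LER = label error rate: edit distance between the decoded and
        # reference label sequences, normalized by the reference length;
        # presumably computed with tf.edit_distance on the sparse tensors)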

        # Define learning rate controller
        lr_controller = Controller(
            learning_rate_init=params['learning_rate'],
            decay_start_epoch=params['decay_start_epoch'],
            decay_rate=params['decay_rate'],
            decay_patient_epoch=1,
            lower_better=True)
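        # lower_better=True: the controller watches the dev CER passed to
        # decay_lr below, and decays the rate once it stops improving
        # (patience is hard-coded to a single epoch here)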

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(model.summaries_train)
        summary_dev = tf.summary.merge(model.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()),
               "{:,}".format(total_parameters / 1000000)))

        csv_steps, csv_loss_train, csv_loss_dev = [], [], []
        csv_ler_main_train, csv_ler_main_dev = [], []
        csv_ler_sub_train, csv_ler_sub_dev = [], []
        # Create a session for running operation on the graph
        with tf.Session() as sess:

            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(
                model.save_path, sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Make mini-batch generator
            mini_batch_train = train_data.next_batch()
            mini_batch_dev = dev_data_step.next_batch()
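            # next_batch() is assumed to yield mini-batches indefinitely
            # (cycling over the data), so the loop below can call next() on
            # every step without handling StopIteration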

            # Train model
            # One epoch = ceil(#utterances / batch_size) mini-batches
            iter_per_epoch = ((train_data.data_num + params['batch_size'] - 1)
                              // params['batch_size'])
            max_steps = iter_per_epoch * params['num_epoch']
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            ler_main_dev_best = 1.0  # best (lowest) dev CER seen so far
            learning_rate = float(params['learning_rate'])
            for step in range(max_steps):

                # Create feed dictionary for next mini batch (train)
                inputs, labels_main, labels_sub, inputs_seq_len, _ = next(
                    mini_batch_train)
                feed_dict_train = {
                    model.inputs_pl_list[0]: inputs,
                    model.labels_pl_list[0]: list2sparsetensor(labels_main, padded_value=-1),
                    model.labels_sub_pl_list[0]: list2sparsetensor(labels_sub, padded_value=-1),
                    model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                    model.keep_prob_input_pl_list[0]: model.dropout_ratio_input,
                    model.keep_prob_hidden_pl_list[0]: model.dropout_ratio_hidden,
                    model.keep_prob_output_pl_list[0]: model.dropout_ratio_output,
                    model.learning_rate_pl_list[0]: learning_rate
                }
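                # list2sparsetensor converts the padded label batches into the
                # tf.SparseTensor form that the CTC-style loss and decoder ops
                # consume; padded_value=-1 marks entries to be dropped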

                # Update parameters
                sess.run(train_op, feed_dict=feed_dict_train)

                if (step + 1) % 200 == 0:

                    # Create feed dictionary for next mini batch (dev)
                    inputs, labels_main, labels_sub, inputs_seq_len, _ = next(
                        mini_batch_dev)
                    feed_dict_dev = {
                        model.inputs_pl_list[0]: inputs,
                        model.labels_pl_list[0]: list2sparsetensor(labels_main, padded_value=-1),
                        model.labels_sub_pl_list[0]: list2sparsetensor(labels_sub, padded_value=-1),
                        model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                        model.keep_prob_input_pl_list[0]: 1.0,
                        model.keep_prob_hidden_pl_list[0]: 1.0,
                        model.keep_prob_output_pl_list[0]: 1.0
                    }
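                    # (keep_prob values of 1.0 disable dropout while scoring
                    # the dev batch)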

                    # Compute loss
                    loss_train = sess.run(loss_op, feed_dict=feed_dict_train)
                    loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                    csv_steps.append(step)
                    csv_loss_train.append(loss_train)
                    csv_loss_dev.append(loss_dev)

                    # Change to evaluation mode
                    feed_dict_train[model.keep_prob_input_pl_list[0]] = 1.0
                    feed_dict_train[model.keep_prob_hidden_pl_list[0]] = 1.0
                    feed_dict_train[model.keep_prob_output_pl_list[0]] = 1.0

                    # Compute accuracy & update event file
                    ler_main_train, ler_sub_train, summary_str_train = sess.run(
                        [ler_op_main, ler_op_sub, summary_train],
                        feed_dict=feed_dict_train)
                    ler_main_dev, ler_sub_dev, summary_str_dev = sess.run(
                        [ler_op_main, ler_op_sub, summary_dev],
                        feed_dict=feed_dict_dev)
                    csv_ler_main_train.append(ler_main_train)
                    csv_ler_main_dev.append(ler_main_dev)
                    csv_ler_sub_train.append(ler_sub_train)
                    csv_ler_sub_dev.append(ler_sub_dev)
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    duration_step = time.time() - start_time_step
                    print('Step %d: loss = %.3f (%.3f) / ler_main = %.4f (%.4f) / ler_sub = %.4f (%.4f) (%.3f min)' %
                          (step + 1, loss_train, loss_dev, ler_main_train, ler_main_dev,
                           ler_sub_train, ler_sub_dev, duration_step / 60))
                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if (step + 1) % iter_per_epoch == 0 or (step + 1) == max_steps:
                    duration_epoch = time.time() - start_time_epoch
                    epoch = (step + 1) // iter_per_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (epoch, duration_epoch / 60))

                    # Save model (check point)
                    checkpoint_file = join(model.save_path, 'model.ckpt')
                    save_path = saver.save(
                        sess, checkpoint_file, global_step=epoch)
                    print("Model saved in file: %s" % save_path)
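                    # (with global_step=epoch the Saver appends the epoch to
                    # the file name, e.g. model.ckpt-3)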

                    if epoch >= 5:
                        start_time_eval = time.time()
                        print('=== Dev Evaluation ===')
                        ler_main_dev_epoch = do_eval_cer(
                            session=sess,
                            decode_op=decode_op_main,
                            model=model,
                            dataset=dev_data_epoch,
                            label_type=params['label_type_main'],
                            eval_batch_size=params['batch_size'],
                            is_multitask=True,
                            is_main=True)
                        print('  CER (main): %f %%' %
                              (ler_main_dev_epoch * 100))

                        ler_sub_dev_epoch = do_eval_cer(
                            session=sess,
                            decode_op=decode_op_sub,
                            model=model,
                            dataset=dev_data_epoch,
                            label_type=params['label_type_sub'],
                            eval_batch_size=params['batch_size'],
                            is_multitask=True,
                            is_main=False)
                        print('  CER (sub): %f %%' %
                              (ler_sub_dev_epoch * 100))

                        if ler_main_dev_epoch < ler_main_dev_best:
                            ler_main_dev_best = ler_main_dev_epoch
                            print('■■■ ↑Best Score (CER main)↑ ■■■')

                        duration_eval = time.time() - start_time_eval
                        print('Evaluation time: %.3f min' %
                              (duration_eval / 60))

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=epoch,
                            value=ler_main_dev_epoch)
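                        # the returned (possibly decayed) rate is fed back
                        # through model.learning_rate_pl_list[0] in the next
                        # steps' feed dictionaries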

                    start_time_epoch = time.time()
                    start_time_step = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Save train & dev loss, ler
            save_loss(csv_steps, csv_loss_train, csv_loss_dev,
                      save_path=model.save_path)
            save_ler(csv_steps, csv_ler_main_train, csv_ler_main_dev,
                     save_path=model.save_path)
            save_ler(csv_steps, csv_ler_sub_train, csv_ler_sub_dev,
                     save_path=model.save_path)
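            # NOTE (assumption): if save_ler derives its file name from
            # save_path alone, the second call may overwrite the first;
            # distinct names for the main/sub curves would avoid that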

            # Training was finished correctly
            with open(join(model.save_path, 'complete.txt'), 'w') as f:
                f.write('')
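

# A minimal, hypothetical sketch of how this trainer could be driven. The
# actual entry point builds the model and the params dict elsewhere; the key
# names below just mirror the lookups used in do_train above.
#
#     params = {
#         'label_type_main': ..., 'label_type_sub': ...,
#         'train_data_size': ..., 'batch_size': 32, 'num_epoch': 30,
#         'num_stack': 3, 'num_skip': 3, 'optimizer': 'adam',
#         'learning_rate': 1e-3, 'decay_start_epoch': 5, 'decay_rate': 0.98,
#     }
#     do_train(model=model, params=params)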