def check(self, encoder_type, lstm_impl, time_major=False):
        """Smoke-test training of MultitaskCTC on a small generated batch.

        Builds a MultitaskCTC graph (main task: characters, sub task:
        phones) and trains it with Adam on one fixed batch for up to 1000
        steps. Every 10 steps it prints the loss, both label error rates and
        one decoded reference/hypothesis pair per task. Training stops early
        once the character (main task) error rate falls below 0.1.

        Args:
            encoder_type: encoder name forwarded to MultitaskCTC
            lstm_impl: LSTM implementation name forwarded to MultitaskCTC
            time_major (bool, optional): forwarded to MultitaskCTC
        """

        print('==================================================')
        print('  encoder_type: %s' % str(encoder_type))
        print('  lstm_impl: %s' % str(lstm_impl))
        print('  time_major: %s' % str(time_major))
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 2
            inputs, labels_char, labels_phone, inputs_seq_len = generate_data(
                label_type='multitask', model='ctc', batch_size=batch_size)

            # Define model graph
            num_classes_main = 27
            num_classes_sub = 61
            model = MultitaskCTC(
                encoder_type=encoder_type,
                input_size=inputs[0].shape[1],
                num_units=256,
                num_layers_main=2,
                num_layers_sub=1,
                num_classes_main=num_classes_main,
                num_classes_sub=num_classes_sub,
                main_task_weight=0.8,
                lstm_impl=lstm_impl,
                parameter_init=0.1,
                clip_grad_norm=5.0,
                clip_activation=50,
                num_proj=256,
                weight_decay=1e-8,
                # bottleneck_dim=50,
                bottleneck_dim=None,
                time_major=time_major)

            # Define placeholders
            model.create_placeholders()
            learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

            # Add to the graph each operation
            loss_op, logits_main, logits_sub = model.compute_loss(
                model.inputs_pl_list[0], model.labels_pl_list[0],
                model.labels_sub_pl_list[0], model.inputs_seq_len_pl_list[0],
                model.keep_prob_pl_list[0])
            train_op = model.train(loss_op,
                                   optimizer='adam',
                                   learning_rate=learning_rate_pl)
            decode_op_main, decode_op_sub = model.decoder(
                logits_main,
                logits_sub,
                model.inputs_seq_len_pl_list[0],
                beam_width=20)
            ler_op_main, ler_op_sub = model.compute_ler(
                decode_op_main, decode_op_sub, model.labels_pl_list[0],
                model.labels_sub_pl_list[0])

            # Define learning rate controller
            learning_rate = 1e-3
            lr_controller = Controller(learning_rate_init=learning_rate,
                                       decay_start_epoch=20,
                                       decay_rate=0.9,
                                       decay_patient_epoch=5,
                                       lower_better=True)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()), "{:,}".format(
                      total_parameters / 1000000)))

            # Dropout keep probability used for training steps; evaluation
            # always runs with dropout disabled (keep_prob = 1.0)
            keep_prob_train = 0.9

            # Make feed dict
            feed_dict = {
                model.inputs_pl_list[0]: inputs,
                model.labels_pl_list[0]: list2sparsetensor(
                    labels_char, padded_value=-1),
                model.labels_sub_pl_list[0]: list2sparsetensor(
                    labels_phone, padded_value=-1),
                model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                model.keep_prob_pl_list[0]: keep_prob_train,
                learning_rate_pl: learning_rate
            }

            idx2phone = Idx2phone(map_file_path='./phone61.txt')

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 1000
                start_time_step = time.time()
                for step in range(max_steps):

                    # Compute loss
                    _, loss_train = sess.run([train_op, loss_op],
                                             feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(model.clipped_grads,
                    #                  feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Evaluate with dropout disabled on a *copy* of the
                        # feed dict. Overwriting feed_dict itself would leave
                        # dropout off for every later training step.
                        feed_dict_eval = dict(feed_dict)
                        feed_dict_eval[model.keep_prob_pl_list[0]] = 1.0

                        # Compute accuracy
                        ler_train_char, ler_train_phone = sess.run(
                            [ler_op_main, ler_op_sub], feed_dict=feed_dict_eval)

                        duration_step = time.time() - start_time_step
                        print(
                            'Step %d: loss = %.3f / cer = %.3f / per = %.3f (%.3f sec) / lr = %.5f'
                            % (step + 1, loss_train, ler_train_char,
                               ler_train_phone, duration_step, learning_rate))
                        start_time_step = time.time()

                        # Visualize
                        labels_pred_char_st, labels_pred_phone_st = sess.run(
                            [decode_op_main, decode_op_sub],
                            feed_dict=feed_dict_eval)
                        labels_pred_char = sparsetensor2list(
                            labels_pred_char_st, batch_size=batch_size)
                        labels_pred_phone = sparsetensor2list(
                            labels_pred_phone_st, batch_size=batch_size)

                        # NOTE: an IndexError here is treated as "no
                        # prediction" and shown as an empty hypothesis
                        try:
                            hyp_char = idx2alpha(labels_pred_char[0])
                        except IndexError:
                            hyp_char = ''
                        try:
                            hyp_phone = idx2phone(labels_pred_phone[0])
                        except IndexError:
                            hyp_phone = ''

                        print('Character')
                        print('  Ref: %s' % idx2alpha(labels_char[0]))
                        print('  Hyp: %s' % hyp_char)
                        print('Phone')
                        print('  Ref: %s' % idx2phone(labels_phone[0]))
                        print('  Hyp: %s' % hyp_phone)
                        print('-' * 30)

                        if ler_train_char < 0.1:
                            print('Model is Converged.')
                            break

                        # Update learning rate (applies to the next
                        # training steps through feed_dict)
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=step,
                            value=ler_train_char)
                        feed_dict[learning_rate_pl] = learning_rate
Пример #2
0
    def check_training(self,
                       encoder_type,
                       label_type,
                       lstm_impl='LSTMBlockCell',
                       save_params=False):
        """Smoke-test training of a CTC model on a single generated batch.

        Builds a CTC graph and trains it with Adam on one fixed batch for up
        to 1000 steps. Every 10 steps it prints the loss, the label error
        rate and one decoded reference/hypothesis pair. Training stops early
        once the label error rate falls below 0.05. When ``save_params`` is
        true, a checkpoint is written to ``./model.ckpt`` at that point.

        Args:
            encoder_type: encoder name forwarded to CTC. VGG/ResNet/CNN
                style encoders listed below are fed spliced (11-frame) inputs.
            label_type: 'character' (26 classes) or anything else (61
                phone classes)
            lstm_impl (str, optional): LSTM implementation forwarded to CTC.
                Parameter counting is skipped for 'CudnnLSTM'.
            save_params (bool, optional): save a checkpoint on convergence
        """

        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  label_type: %s' % label_type)
        print('  lstm_impl: %s' % lstm_impl)
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 1
            splice = 11 if encoder_type in [
                'vgg_blstm', 'vgg_lstm', 'vgg_wang', 'resnet_wang', 'cnn_zhang'
            ] else 1
            inputs, labels_true_st, inputs_seq_len = generate_data(
                label_type=label_type,
                model='ctc',
                batch_size=batch_size,
                splice=splice)
            # NOTE: input_size must be even number when using CudnnLSTM

            # Define model graph
            num_classes = 26 if label_type == 'character' else 61
            model = CTC(
                encoder_type=encoder_type,
                input_size=inputs[0].shape[-1] // splice,
                splice=splice,
                num_units=256,
                num_layers=2,
                num_classes=num_classes,
                lstm_impl=lstm_impl,
                parameter_init=0.1,
                clip_grad=5.0,
                clip_activation=50,
                num_proj=256,
                # bottleneck_dim=50,
                bottleneck_dim=None,
                weight_decay=1e-8)

            # Define placeholders
            model.create_placeholders()
            learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

            # Add to the graph each operation
            loss_op, logits = model.compute_loss(
                model.inputs_pl_list[0], model.labels_pl_list[0],
                model.inputs_seq_len_pl_list[0],
                model.keep_prob_input_pl_list[0],
                model.keep_prob_hidden_pl_list[0],
                model.keep_prob_output_pl_list[0])
            train_op = model.train(loss_op,
                                   optimizer='adam',
                                   learning_rate=learning_rate_pl)
            # NOTE: Adam does not run on CudnnLSTM
            decode_op = model.decoder(logits,
                                      model.inputs_seq_len_pl_list[0],
                                      beam_width=20)
            ler_op = model.compute_ler(decode_op, model.labels_pl_list[0])

            # Define learning rate controller
            learning_rate = 1e-3
            lr_controller = Controller(learning_rate_init=learning_rate,
                                       decay_start_epoch=10,
                                       decay_rate=0.98,
                                       decay_patient_epoch=5,
                                       lower_better=True)

            if save_params:
                # Create a saver for writing training checkpoints
                saver = tf.train.Saver(max_to_keep=None)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            if lstm_impl != 'CudnnLSTM':
                parameters_dict, total_parameters = count_total_parameters(
                    tf.trainable_variables())
                for parameter_name in sorted(parameters_dict.keys()):
                    print("%s %d" %
                          (parameter_name, parameters_dict[parameter_name]))
                print("Total %d variables, %s M parameters" %
                      (len(parameters_dict.keys()), "{:,}".format(
                          total_parameters / 1000000)))

            # Make feed dict (dropout is disabled for training here as well)
            feed_dict = {
                model.inputs_pl_list[0]: inputs,
                model.labels_pl_list[0]: labels_true_st,
                model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                model.keep_prob_input_pl_list[0]: 1.0,
                model.keep_prob_hidden_pl_list[0]: 1.0,
                model.keep_prob_output_pl_list[0]: 1.0,
                learning_rate_pl: learning_rate
            }

            idx2phone = Idx2phone(map_file_path='./phone61_ctc.txt')
            # Index-to-text mapping for visualization, chosen once
            map_fn = idx2alpha if label_type == 'character' else idx2phone
            # The reference labels never change, so convert them only once
            labels_true = sparsetensor2list(labels_true_st,
                                            batch_size=batch_size)

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 1000
                start_time_global = time.time()
                start_time_step = time.time()
                for step in range(max_steps):

                    # Compute loss
                    _, loss_train = sess.run([train_op, loss_op],
                                             feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(model.clipped_grads,
                    #                  feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Evaluate with dropout disabled on a copy of the
                        # feed dict so training settings are never mutated
                        feed_dict_eval = dict(feed_dict)
                        feed_dict_eval[model.keep_prob_input_pl_list[0]] = 1.0
                        feed_dict_eval[model.keep_prob_hidden_pl_list[0]] = 1.0
                        feed_dict_eval[model.keep_prob_output_pl_list[0]] = 1.0

                        # Compute accuracy
                        ler_train = sess.run(ler_op, feed_dict=feed_dict_eval)

                        duration_step = time.time() - start_time_step
                        print(
                            'Step %d: loss = %.3f / ler = %.3f (%.3f sec) / lr = %.5f'
                            % (step + 1, loss_train, ler_train, duration_step,
                               learning_rate))
                        start_time_step = time.time()

                        # Decode
                        labels_pred_st = sess.run(decode_op,
                                                  feed_dict=feed_dict_eval)

                        # Visualize
                        # NOTE: an IndexError here is treated as "no
                        # prediction" and shown as an empty hypothesis
                        try:
                            labels_pred = sparsetensor2list(
                                labels_pred_st, batch_size=batch_size)
                            hyp = map_fn(labels_pred[0])
                        except IndexError:
                            hyp = ''
                        print('Ref: %s' % map_fn(labels_true[0]))
                        print('Hyp: %s' % hyp)

                        if ler_train < 0.05:
                            print('Model is Converged.')
                            if save_params:
                                # Save model (check point)
                                checkpoint_file = './model.ckpt'
                                save_path = saver.save(sess,
                                                       checkpoint_file,
                                                       global_step=1)
                                print("Model saved in file: %s" % save_path)
                            break

                        # Update learning rate (applies to the next
                        # training steps through feed_dict)
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=step,
                            value=ler_train)
                        feed_dict[learning_rate_pl] = learning_rate

                duration_global = time.time() - start_time_global
                print('Total time: %.3f sec' % (duration_global))
    def check(self, encoder_type, lstm_impl, time_major=False):
        """Smoke-test training of MultitaskCTC on a small generated batch.

        Builds a MultitaskCTC graph (main task: characters, sub task:
        phones) and trains it with Adam on one fixed batch for up to 1000
        steps. Every 10 steps it prints the loss, both label error rates and
        one decoded reference/hypothesis pair per task. Training stops early
        once the character (main task) error rate falls below 0.1.

        Args:
            encoder_type: encoder name forwarded to MultitaskCTC
            lstm_impl: LSTM implementation name forwarded to MultitaskCTC
            time_major (bool, optional): forwarded to MultitaskCTC
        """

        print('==================================================')
        print('  encoder_type: %s' % str(encoder_type))
        print('  lstm_impl: %s' % str(lstm_impl))
        print('  time_major: %s' % str(time_major))
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 2
            inputs, labels_char, labels_phone, inputs_seq_len = generate_data(
                label_type='multitask',
                model='ctc',
                batch_size=batch_size)

            # Define model graph
            num_classes_main = 27
            num_classes_sub = 61
            model = MultitaskCTC(
                encoder_type=encoder_type,
                input_size=inputs[0].shape[1],
                num_units=256,
                num_layers_main=2,
                num_layers_sub=1,
                num_classes_main=num_classes_main,
                num_classes_sub=num_classes_sub,
                main_task_weight=0.8,
                lstm_impl=lstm_impl,
                parameter_init=0.1,
                clip_grad_norm=5.0,
                clip_activation=50,
                num_proj=256,
                weight_decay=1e-8,
                # bottleneck_dim=50,
                bottleneck_dim=None,
                time_major=time_major)

            # Define placeholders
            model.create_placeholders()
            learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

            # Add to the graph each operation
            loss_op, logits_main, logits_sub = model.compute_loss(
                model.inputs_pl_list[0],
                model.labels_pl_list[0],
                model.labels_sub_pl_list[0],
                model.inputs_seq_len_pl_list[0],
                model.keep_prob_pl_list[0])
            train_op = model.train(
                loss_op,
                optimizer='adam',
                learning_rate=learning_rate_pl)
            decode_op_main, decode_op_sub = model.decoder(
                logits_main,
                logits_sub,
                model.inputs_seq_len_pl_list[0],
                beam_width=20)
            ler_op_main, ler_op_sub = model.compute_ler(
                decode_op_main, decode_op_sub,
                model.labels_pl_list[0], model.labels_sub_pl_list[0])

            # Define learning rate controller
            learning_rate = 1e-3
            lr_controller = Controller(learning_rate_init=learning_rate,
                                       decay_start_epoch=20,
                                       decay_rate=0.9,
                                       decay_patient_epoch=5,
                                       lower_better=True)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()),
                   "{:,}".format(total_parameters / 1000000)))

            # Dropout keep probability used for training steps; evaluation
            # always runs with dropout disabled (keep_prob = 1.0)
            keep_prob_train = 0.9

            # Make feed dict
            feed_dict = {
                model.inputs_pl_list[0]: inputs,
                model.labels_pl_list[0]: list2sparsetensor(labels_char, padded_value=-1),
                model.labels_sub_pl_list[0]: list2sparsetensor(labels_phone, padded_value=-1),
                model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                model.keep_prob_pl_list[0]: keep_prob_train,
                learning_rate_pl: learning_rate
            }

            idx2phone = Idx2phone(map_file_path='./phone61.txt')

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 1000
                start_time_step = time.time()
                for step in range(max_steps):

                    # Compute loss
                    _, loss_train = sess.run(
                        [train_op, loss_op], feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(model.clipped_grads,
                    #                  feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Evaluate with dropout disabled on a *copy* of the
                        # feed dict. Overwriting feed_dict itself would leave
                        # dropout off for every later training step.
                        feed_dict_eval = dict(feed_dict)
                        feed_dict_eval[model.keep_prob_pl_list[0]] = 1.0

                        # Compute accuracy
                        ler_train_char, ler_train_phone = sess.run(
                            [ler_op_main, ler_op_sub], feed_dict=feed_dict_eval)

                        duration_step = time.time() - start_time_step
                        print('Step %d: loss = %.3f / cer = %.3f / per = %.3f (%.3f sec) / lr = %.5f' %
                              (step + 1, loss_train, ler_train_char,
                               ler_train_phone, duration_step, learning_rate))
                        start_time_step = time.time()

                        # Visualize
                        labels_pred_char_st, labels_pred_phone_st = sess.run(
                            [decode_op_main, decode_op_sub],
                            feed_dict=feed_dict_eval)
                        labels_pred_char = sparsetensor2list(
                            labels_pred_char_st, batch_size=batch_size)
                        labels_pred_phone = sparsetensor2list(
                            labels_pred_phone_st, batch_size=batch_size)

                        # NOTE: an IndexError here is treated as "no
                        # prediction" and shown as an empty hypothesis
                        try:
                            hyp_char = idx2alpha(labels_pred_char[0])
                        except IndexError:
                            hyp_char = ''
                        try:
                            hyp_phone = idx2phone(labels_pred_phone[0])
                        except IndexError:
                            hyp_phone = ''

                        print('Character')
                        print('  Ref: %s' % idx2alpha(labels_char[0]))
                        print('  Hyp: %s' % hyp_char)
                        print('Phone')
                        print('  Ref: %s' % idx2phone(labels_phone[0]))
                        print('  Hyp: %s' % hyp_phone)
                        print('-' * 30)

                        if ler_train_char < 0.1:
                            print('Model is Converged.')
                            break

                        # Update learning rate (applies to the next
                        # training steps through feed_dict)
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=step,
                            value=ler_train_char)
                        feed_dict[learning_rate_pl] = learning_rate
    def check(self, encoder_type, label_type='character',
              lstm_impl=None, time_major=True, save_params=False):
        """Smoke-test training of a CTC model with frame stacking.

        Builds a CTC graph on generated data (frames stacked by 2, and
        spliced by 11 for the CNN-style encoders listed below). It trains
        with the 'nestrov' optimizer on one fixed batch for up to 1000
        steps. Every 10 steps it prints the loss, the label error rate and
        one decoded reference/hypothesis pair. Training stops early once
        the label error rate falls below 0.1. When ``save_params`` is true,
        a checkpoint is written to ``./model.ckpt`` at that point.

        Args:
            encoder_type: encoder name forwarded to CTC
            label_type (str, optional): 'character' (27 classes) or anything
                else (61 phone classes)
            lstm_impl (optional): LSTM implementation forwarded to CTC.
                Parameter counting is skipped for 'CudnnLSTM'.
            time_major (bool, optional): forwarded to CTC
            save_params (bool, optional): save a checkpoint on convergence
        """

        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  label_type: %s' % label_type)
        print('  lstm_impl: %s' % lstm_impl)
        print('  time_major: %s' % str(time_major))
        print('  save_params: %s' % str(save_params))
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 2
            splice = 11 if encoder_type in ['vgg_blstm', 'vgg_lstm', 'cnn_zhang',
                                            'vgg_wang', 'resnet_wang', 'cldnn_wang'] else 1
            num_stack = 2
            inputs, labels, inputs_seq_len = generate_data(
                label_type=label_type,
                model='ctc',
                batch_size=batch_size,
                num_stack=num_stack,
                splice=splice)
            # NOTE: input_size must be even number when using CudnnLSTM

            # Define model graph
            num_classes = 27 if label_type == 'character' else 61
            model = CTC(encoder_type=encoder_type,
                        input_size=inputs[0].shape[-1] // splice // num_stack,
                        splice=splice,
                        num_stack=num_stack,
                        num_units=256,
                        num_layers=2,
                        num_classes=num_classes,
                        lstm_impl=lstm_impl,
                        parameter_init=0.1,
                        clip_grad_norm=5.0,
                        clip_activation=50,
                        num_proj=256,
                        weight_decay=1e-10,
                        # bottleneck_dim=50,
                        bottleneck_dim=None,
                        time_major=time_major)

            # Define placeholders
            model.create_placeholders()
            learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

            # Add to the graph each operation
            loss_op, logits = model.compute_loss(
                model.inputs_pl_list[0],
                model.labels_pl_list[0],
                model.inputs_seq_len_pl_list[0],
                model.keep_prob_pl_list[0])
            train_op = model.train(loss_op,
                                   optimizer='nestrov',
                                   learning_rate=learning_rate_pl)
            # NOTE: Adam does not run on CudnnLSTM
            decode_op = model.decoder(logits,
                                      model.inputs_seq_len_pl_list[0],
                                      beam_width=20)
            ler_op = model.compute_ler(decode_op, model.labels_pl_list[0])

            # Define learning rate controller
            learning_rate = 1e-4
            lr_controller = Controller(learning_rate_init=learning_rate,
                                       decay_start_epoch=50,
                                       decay_rate=0.9,
                                       decay_patient_epoch=10,
                                       lower_better=True)

            if save_params:
                # Create a saver for writing training checkpoints
                saver = tf.train.Saver(max_to_keep=None)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            if lstm_impl != 'CudnnLSTM':
                parameters_dict, total_parameters = count_total_parameters(
                    tf.trainable_variables())
                for parameter_name in sorted(parameters_dict.keys()):
                    print("%s %d" %
                          (parameter_name, parameters_dict[parameter_name]))
                print("Total %d variables, %s M parameters" %
                      (len(parameters_dict.keys()),
                       "{:,}".format(total_parameters / 1000000)))

            # Make feed dict (dropout is disabled for training here as well)
            feed_dict = {
                model.inputs_pl_list[0]: inputs,
                model.labels_pl_list[0]: list2sparsetensor(labels, padded_value=-1),
                model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                model.keep_prob_pl_list[0]: 1.0,
                learning_rate_pl: learning_rate
            }

            idx2phone = Idx2phone(map_file_path='./phone61.txt')
            # Index-to-text mapping for visualization, chosen once
            map_fn = idx2alpha if label_type == 'character' else idx2phone

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 1000
                start_time_step = time.time()
                for step in range(max_steps):

                    # for debug
                    # encoder_outputs = sess.run(
                    #     model.encoder_outputs, feed_dict)
                    # print(encoder_outputs.shape)

                    # Compute loss
                    _, loss_train = sess.run(
                        [train_op, loss_op], feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(model.clipped_grads,
                    #                  feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Evaluate with dropout disabled on a copy of the
                        # feed dict so training settings are never mutated
                        feed_dict_eval = dict(feed_dict)
                        feed_dict_eval[model.keep_prob_pl_list[0]] = 1.0

                        # Compute accuracy
                        ler_train = sess.run(ler_op, feed_dict=feed_dict_eval)

                        duration_step = time.time() - start_time_step
                        print('Step %d: loss = %.3f / ler = %.3f (%.3f sec) / lr = %.5f' %
                              (step + 1, loss_train, ler_train, duration_step, learning_rate))
                        start_time_step = time.time()

                        # Decode
                        labels_pred_st = sess.run(
                            decode_op, feed_dict=feed_dict_eval)

                        # Visualize
                        # NOTE: an IndexError here is treated as "no
                        # prediction" and shown as an empty hypothesis
                        try:
                            labels_pred = sparsetensor2list(
                                labels_pred_st, batch_size=batch_size)
                            hyp = map_fn(labels_pred[0])
                        except IndexError:
                            hyp = ''
                        print('Ref: %s' % map_fn(labels[0]))
                        print('Hyp: %s' % hyp)

                        if ler_train < 0.1:
                            print('Model is Converged.')
                            if save_params:
                                # Save model (check point)
                                checkpoint_file = './model.ckpt'
                                save_path = saver.save(
                                    sess, checkpoint_file, global_step=2)
                                print("Model saved in file: %s" % save_path)
                            break

                        # Update learning rate (applies to the next
                        # training steps through feed_dict)
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=step,
                            value=ler_train)
                        feed_dict[learning_rate_pl] = learning_rate
Пример #5
0
    def check(self,
              rnn_type,
              bidirectional=False,
              label_type='char',
              tie_weights=False):
        """Train an RNNLM on a tiny synthetic batch and print progress.

        Builds an ``RNNLM`` and trains it with Adam for up to 1000 steps on
        data from ``generate_data(model_type='lm', ...)``. Every 10 steps it
        re-computes the loss in evaluation mode and prints loss, perplexity,
        learning rate and elapsed time. It then lets the learning-rate
        controller decay the learning rate.

        Args:
            rnn_type (str): RNN cell type forwarded to ``RNNLM``.
            bidirectional (bool): forwarded to ``RNNLM``.
            label_type (str): 'char' (27 classes) or 'word' (11 classes).
            tie_weights (bool): forwarded to ``RNNLM`` (presumably ties the
                embedding and output weights -- confirm in RNNLM).
        Raises:
            ValueError: if ``label_type`` is neither 'char' nor 'word'.
        """
        print('==================================================')
        print('  label_type: %s' % label_type)
        print('  rnn_type: %s' % rnn_type)
        print('  bidirectional: %s' % str(bidirectional))
        print('  tie_weights: %s' % str(tie_weights))
        print('==================================================')

        # Validate the label type up front; previously an unknown value left
        # num_classes unbound and surfaced later as a NameError.
        if label_type == 'char':
            num_classes = 27
            map_fn = idx2char  # used by the (commented-out) decoding check
        elif label_type == 'word':
            num_classes = 11
            map_fn = idx2word  # used by the (commented-out) decoding check
        else:
            raise ValueError(
                "label_type must be 'char' or 'word', got %r" % (label_type,))

        # Load batch data
        _, ys, _, y_lens = generate_data(model_type='lm',
                                         label_type=label_type,
                                         batch_size=2)

        # Load model (tie_weights is now forwarded instead of hard-coded False)
        model = RNNLM(num_classes,
                      embedding_dim=128,
                      rnn_type=rnn_type,
                      bidirectional=bidirectional,
                      num_units=1024,
                      num_layers=1,
                      dropout_embedding=0.1,
                      dropout_hidden=0.1,
                      dropout_output=0.1,
                      parameter_init_distribution='uniform',
                      parameter_init=0.1,
                      tie_weights=tie_weights)

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            print("%s %d" % (name, num_params))
        print("Total %.3f M parameters" % (model.total_parameters / 1000000))

        # Define optimizer
        learning_rate = 1e-3
        model.set_optimizer('adam',
                            learning_rate_init=learning_rate,
                            weight_decay=1e-8,
                            lr_schedule=False,
                            factor=0.1,
                            patience_epoch=5)

        # Define learning rate controller
        lr_controller = Controller(learning_rate_init=learning_rate,
                                   backend='pytorch',
                                   decay_start_epoch=20,
                                   decay_rate=0.9,
                                   decay_patient_epoch=10,
                                   lower_better=True)

        # GPU setting
        model.set_cuda(deterministic=False, benchmark=True)

        # Train model
        max_step = 1000
        start_time_step = time.time()
        for step in range(max_step):

            # Step for parameter update
            model.optimizer.zero_grad()
            loss = model(ys, y_lens)
            loss.backward()
            nn.utils.clip_grad_norm(model.parameters(), 5)
            model.optimizer.step()

            # Inject Gaussian noise to all parameters once the training loss
            # drops below 50; the flag is never reset afterwards.
            if loss.data[0] < 50:
                model.weight_noise_injection = True

            if (step + 1) % 10 == 0:
                # Compute loss
                loss = model(ys, y_lens, is_eval=True)

                # Compute PPL
                ppl = math.exp(loss)

                # Decode
                # best_hyps, perm_idx = model.decode(
                #     xs, x_lens,
                #     # beam_width=1,
                #     beam_width=2,
                #     max_decode_len=60)

                # Compute accuracy
                # if label_type == 'char':
                #     str_true = map_fn(ys[0, :y_lens[0]][1:-1])
                #     str_pred = map_fn(best_hyps[0][0:-1]).split('>')[0]
                #     ler = compute_cer(ref=str_true.replace('_', ''),
                #                       hyp=str_pred.replace('_', ''),
                #                       normalize=True)
                # elif label_type == 'word':
                #     str_true = map_fn(ys[0, : y_lens[0]][1: -1])
                #     str_pred = map_fn(best_hyps[0][0: -1]).split('>')[0]
                #     ler, _, _, _ = compute_wer(ref=str_true.split('_'),
                #                                hyp=str_pred.split('_'),
                #                                normalize=True)

                duration_step = time.time() - start_time_step
                print('Step %d: loss=%.3f / ppl=%.3f / lr=%.5f (%.3f sec)' %
                      (step + 1, loss, ppl, learning_rate, duration_step))
                start_time_step = time.time()

                # Visualize
                # print('Ref: %s' % str_true)
                # print('Hyp: %s' % str_pred)

                # NOTE(review): ppl = math.exp(loss) is > 0 for any realistic
                # loss, so this early exit never fires as written; confirm
                # the intended convergence threshold.
                if ppl == 0:
                    print('Modle is Converged.')
                    break

                # Update learning rate
                model.optimizer, learning_rate = lr_controller.decay_lr(
                    optimizer=model.optimizer,
                    learning_rate=learning_rate,
                    epoch=step,
                    value=ppl)
Пример #6
0
    def check_training(self, label_type):
        """Train a JointCTCAttention model on one synthetic utterance.

        Builds the TF1 graph (placeholders, loss, train, decode and LER ops)
        and runs up to 400 Adam steps. Every 10 steps it decodes without
        dropout, computes the label error rate against ``att_labels`` and
        prints the reference and the training/inference hypotheses. It stops
        once the LER has not improved for 10 consecutive evaluations.

        Args:
            label_type (str): 'character' uses 26 CTC classes (+2 for the
                sos/eos indices on the attention side); any other value uses
                61 phone classes (+2).
        """
        print('----- ' + label_type + ' -----')
        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 1
            inputs, att_labels, inputs_seq_len, att_labels_seq_len, ctc_labels_st = generate_data(
                label_type=label_type,
                model='joint_ctc_attention',
                batch_size=batch_size)

            # Define model graph
            # The two extra attention classes are the sos/eos indices below.
            att_num_classes = 26 + 2 if label_type == 'character' else 61 + 2
            ctc_num_classes = 26 if label_type == 'character' else 61
            # model = load(model_type=model_type)
            network = JointCTCAttention(input_size=inputs[0].shape[1],
                                        encoder_num_unit=256,
                                        encoder_num_layer=2,
                                        attention_dim=128,
                                        attention_type='content',
                                        decoder_num_unit=256,
                                        decoder_num_layer=1,
                                        embedding_dim=20,
                                        att_num_classes=att_num_classes,
                                        ctc_num_classes=ctc_num_classes,
                                        att_task_weight=0.5,
                                        sos_index=att_num_classes - 2,
                                        eos_index=att_num_classes - 1,
                                        max_decode_length=50,
                                        attention_weights_tempareture=1.0,
                                        logits_tempareture=1.0,
                                        parameter_init=0.1,
                                        clip_grad=5.0,
                                        clip_activation_encoder=50,
                                        clip_activation_decoder=50,
                                        dropout_ratio_input=0.9,
                                        dropout_ratio_hidden=0.9,
                                        dropout_ratio_output=1.0,
                                        weight_decay=1e-8,
                                        beam_width=1,
                                        time_major=False)

            # Define placeholders
            network.create_placeholders()
            learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

            # Add to the graph each operation
            loss_op, att_logits, ctc_logits, decoder_outputs_train, decoder_outputs_infer = network.compute_loss(
                network.inputs_pl_list[0], network.att_labels_pl_list[0],
                network.inputs_seq_len_pl_list[0],
                network.att_labels_seq_len_pl_list[0],
                network.ctc_labels_pl_list[0],
                network.keep_prob_input_pl_list[0],
                network.keep_prob_hidden_pl_list[0],
                network.keep_prob_output_pl_list[0])
            train_op = network.train(loss_op,
                                     optimizer='adam',
                                     learning_rate=learning_rate_pl)
            decode_op_train, decode_op_infer = network.decoder(
                decoder_outputs_train, decoder_outputs_infer)
            ler_op = network.compute_ler(network.att_labels_st_true_pl,
                                         network.att_labels_st_pred_pl)

            # Define learning rate controller
            learning_rate = 1e-3
            lr_controller = Controller(learning_rate_init=learning_rate,
                                       decay_start_epoch=10,
                                       decay_rate=0.99,
                                       decay_patient_epoch=5,
                                       lower_better=True)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()), "{:,}".format(
                      total_parameters / 1000000)))

            # Make feed dict (training mode: dropout keep-probabilities)
            feed_dict = {
                network.inputs_pl_list[0]: inputs,
                network.att_labels_pl_list[0]: att_labels,
                network.inputs_seq_len_pl_list[0]: inputs_seq_len,
                network.att_labels_seq_len_pl_list[0]: att_labels_seq_len,
                network.ctc_labels_pl_list[0]: ctc_labels_st,
                network.keep_prob_input_pl_list[0]:
                network.dropout_ratio_input,
                network.keep_prob_hidden_pl_list[0]:
                network.dropout_ratio_hidden,
                network.keep_prob_output_pl_list[0]:
                network.dropout_ratio_output,
                learning_rate_pl: learning_rate
            }

            map_file_path = '../../experiments/timit/metrics/mapping_files/attention/phone61_to_num.txt'

            with tf.Session() as sess:

                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 400
                start_time_global = time.time()
                start_time_step = time.time()
                ler_train_pre = 1
                not_improved_count = 0
                for step in range(max_steps):

                    # Compute loss
                    _, loss_train = sess.run([train_op, loss_op],
                                             feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(network.clipped_grads,
                    #                  feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Change to evaluation mode on a *copy* of the feed
                        # dict. Mutating feed_dict in place (as before)
                        # disabled dropout for every later training step.
                        feed_dict_eval = dict(feed_dict)
                        feed_dict_eval[network.keep_prob_input_pl_list[0]] = 1.0
                        feed_dict_eval[network.keep_prob_hidden_pl_list[0]] = 1.0
                        feed_dict_eval[network.keep_prob_output_pl_list[0]] = 1.0

                        # Predict class ids
                        predicted_ids_train, predicted_ids_infer = sess.run(
                            [decode_op_train, decode_op_infer],
                            feed_dict=feed_dict_eval)

                        # Compute accuracy
                        # NOTE(review): padded_value=27 is used for both label
                        # types; for phones 27 may be a real label index --
                        # confirm against generate_data's padding value.
                        try:
                            feed_dict_ler = {
                                network.att_labels_st_true_pl:
                                list2sparsetensor(att_labels, padded_value=27),
                                network.att_labels_st_pred_pl:
                                list2sparsetensor(predicted_ids_infer,
                                                  padded_value=27)
                            }
                            ler_train = sess.run(ler_op,
                                                 feed_dict=feed_dict_ler)
                        except ValueError:
                            ler_train = 1

                        duration_step = time.time() - start_time_step
                        print(
                            'Step %d: loss = %.3f / ler = %.4f (%.3f sec) / lr = %.5f'
                            % (step + 1, loss_train, ler_train, duration_step,
                               learning_rate))
                        start_time_step = time.time()

                        # Visualize
                        if label_type == 'character':
                            print('True            : %s' %
                                  idx2alpha(att_labels[0]))
                            print('Pred (Training) : <%s' %
                                  idx2alpha(predicted_ids_train[0]))
                            print('Pred (Inference): <%s' %
                                  idx2alpha(predicted_ids_infer[0]))
                        else:
                            print('True            : %s' %
                                  idx2phone(att_labels[0], map_file_path))
                            print('Pred (Training) : < %s' % idx2phone(
                                predicted_ids_train[0], map_file_path))
                            print('Pred (Inference): < %s' % idx2phone(
                                predicted_ids_infer[0], map_file_path))

                        # Early stopping: 10 evaluations without LER improvement
                        if ler_train >= ler_train_pre:
                            not_improved_count += 1
                        else:
                            not_improved_count = 0
                        if not_improved_count >= 10:
                            print('Model is Converged.')
                            break
                        ler_train_pre = ler_train

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=step,
                            value=ler_train)
                        feed_dict[learning_rate_pl] = learning_rate

                duration_global = time.time() - start_time_global
                print('Total time: %.3f sec' % (duration_global))
    def check(self, encoder_type, lstm_impl=None, time_major=False):
        """Build one encoder, run a forward pass and check output shapes.

        Args:
            encoder_type (str): name passed to ``load`` (e.g. 'blstm', 'gru',
                'vgg_blstm', 'multitask_lstm', 'vgg_wang', ...).
            lstm_impl: LSTM implementation forwarded to LSTM-based encoders.
            time_major (bool): forwarded to the encoder. When True the
                outputs are transposed back to batch-major before the shape
                assertions.
        Raises:
            NotImplementedError: for an unsupported ``encoder_type``.
        """
        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  lstm_impl: %s' % lstm_impl)
        print('  time_major: %s' % time_major)
        print('==================================================')

        multitask_types = ['multitask_blstm', 'multitask_lstm']
        cnn_only_types = ['vgg_wang', 'resnet_wang', 'cnn_zhang']

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 4
            splice = 5 if encoder_type in ['vgg_blstm', 'vgg_lstm',
                                           'vgg_wang', 'resnet_wang', 'cldnn_wang',
                                           'cnn_zhang'] else 1
            num_stack = 2
            inputs, _, inputs_seq_len = generate_data(
                label_type='character',
                model='ctc',
                batch_size=batch_size,
                num_stack=num_stack,
                splice=splice)
            frame_num, input_size = inputs[0].shape

            # Define model graph
            if encoder_type in ['blstm', 'lstm']:
                encoder = load(encoder_type)(
                    num_units=256,
                    num_proj=None,
                    num_layers=5,
                    lstm_impl=lstm_impl,
                    use_peephole=True,
                    parameter_init=0.1,
                    clip_activation=5,
                    time_major=time_major)
            elif encoder_type in ['bgru', 'gru']:
                encoder = load(encoder_type)(
                    num_units=256,
                    num_layers=5,
                    parameter_init=0.1,
                    time_major=time_major)
            elif encoder_type in ['vgg_blstm', 'vgg_lstm', 'cldnn_wang']:
                encoder = load(encoder_type)(
                    input_size=input_size // splice // num_stack,
                    splice=splice,
                    num_stack=num_stack,
                    num_units=256,
                    num_proj=None,
                    num_layers=5,
                    lstm_impl=lstm_impl,
                    use_peephole=True,
                    parameter_init=0.1,
                    clip_activation=5,
                    time_major=time_major)
            elif encoder_type in multitask_types:
                encoder = load(encoder_type)(
                    num_units=256,
                    num_proj=None,
                    num_layers_main=5,
                    num_layers_sub=3,
                    lstm_impl=lstm_impl,
                    use_peephole=True,
                    parameter_init=0.1,
                    clip_activation=5,
                    time_major=time_major)
            elif encoder_type in cnn_only_types:
                encoder = load(encoder_type)(
                    input_size=input_size // splice // num_stack,
                    splice=splice,
                    num_stack=num_stack,
                    parameter_init=0.1,
                    time_major=time_major)
                # NOTE: topology is pre-defined
            else:
                raise NotImplementedError

            # Create placeholders
            inputs_pl = tf.placeholder(tf.float32,
                                       shape=[None, None, input_size],
                                       name='inputs')
            inputs_seq_len_pl = tf.placeholder(tf.int32,
                                               shape=[None],
                                               name='inputs_seq_len')
            keep_prob_pl = tf.placeholder(tf.float32, name='keep_prob')

            # operation for forward computation
            if encoder_type in multitask_types:
                hidden_states_op, final_state_op, hidden_states_sub_op, final_state_sub_op = encoder(
                    inputs=inputs_pl,
                    inputs_seq_len=inputs_seq_len_pl,
                    keep_prob=keep_prob_pl,
                    is_training=True)
            else:
                hidden_states_op, final_state_op = encoder(
                    inputs=inputs_pl,
                    inputs_seq_len=inputs_seq_len_pl,
                    keep_prob=keep_prob_pl,
                    is_training=True)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()),
                   "{:,}".format(total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                inputs_pl: inputs,
                inputs_seq_len_pl: inputs_seq_len,
                keep_prob_pl: 0.9
            }

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Make prediction
                if encoder_type in multitask_types:
                    encoder_outputs, final_state, hidden_states_sub, final_state_sub = sess.run(
                        [hidden_states_op, final_state_op,
                         hidden_states_sub_op, final_state_sub_op],
                        feed_dict=feed_dict)
                elif encoder_type in cnn_only_types:
                    encoder_outputs = sess.run(
                        hidden_states_op, feed_dict=feed_dict)
                else:
                    encoder_outputs, final_state = sess.run(
                        [hidden_states_op, final_state_op],
                        feed_dict=feed_dict)

                # Convert always to batch-major
                if time_major:
                    encoder_outputs = encoder_outputs.transpose(1, 0, 2)
                    if encoder_type in multitask_types:
                        # The sub-task outputs come from the same time-major
                        # encoder (presumably same layout) and are asserted
                        # against a batch-major shape below, so convert too.
                        hidden_states_sub = hidden_states_sub.transpose(
                            1, 0, 2)

                if encoder_type in ['blstm', 'bgru', 'vgg_blstm', 'multitask_blstm', 'cldnn_wang']:
                    if encoder_type != 'cldnn_wang':
                        self.assertEqual(
                            (batch_size, frame_num, encoder.num_units * 2), encoder_outputs.shape)

                    if encoder_type != 'bgru':
                        # LSTM-based: (fw, bw) states each carry c and h
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].c.shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].h.shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[1].c.shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[1].h.shape)

                        if encoder_type == 'multitask_blstm':
                            self.assertEqual(
                                (batch_size, frame_num, encoder.num_units * 2), hidden_states_sub.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[0].c.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[0].h.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[1].c.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[1].h.shape)
                    else:
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[1].shape)

                elif encoder_type in ['lstm', 'gru', 'vgg_lstm', 'multitask_lstm']:
                    self.assertEqual(
                        (batch_size, frame_num, encoder.num_units), encoder_outputs.shape)

                    if encoder_type != 'gru':
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].c.shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].h.shape)

                        if encoder_type == 'multitask_lstm':
                            self.assertEqual(
                                (batch_size, frame_num, encoder.num_units), hidden_states_sub.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[0].c.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[0].h.shape)
                    else:
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].shape)

                elif encoder_type in cnn_only_types:
                    self.assertEqual(3, len(encoder_outputs.shape))
                    self.assertEqual(
                        (batch_size, frame_num), encoder_outputs.shape[:2])
    def check_encode(self, encoder_type, lstm_impl=None):
        """Build one encoder (hidden-state mode), run it and check shapes.

        Most encoders are built with ``num_classes=0`` (commented as "return
        hidden states"). The CNN-only encoders are built with 27 classes and
        their output is asserted as (frame_num, batch_size, num_classes).

        Args:
            encoder_type (str): name passed to ``load``.
            lstm_impl: LSTM implementation forwarded to LSTM-based encoders.
        Raises:
            NotImplementedError: for an unsupported ``encoder_type``.
        """
        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  lstm_impl: %s' % lstm_impl)
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 4
            splice = 11 if encoder_type in ['vgg_blstm', 'vgg_lstm', 'vgg_wang',
                                            'resnet_wang', 'cnn_zhang'] else 1
            inputs, _, inputs_seq_len = generate_data(
                label_type='character',
                model='ctc',
                batch_size=batch_size,
                splice=splice)
            frame_num, input_size = inputs[0].shape

            # Define model graph
            if encoder_type in ['blstm', 'lstm']:
                encoder = load(encoder_type)(
                    num_units=256,
                    num_layers=5,
                    num_classes=0,  # return hidden states
                    lstm_impl=lstm_impl,
                    parameter_init=0.1)
            elif encoder_type in ['bgru', 'gru']:
                encoder = load(encoder_type)(
                    num_units=256,
                    num_layers=5,
                    num_classes=0,  # return hidden states
                    parameter_init=0.1)
            elif encoder_type in ['vgg_blstm', 'vgg_lstm']:
                # Use the splice value the data was generated with instead of
                # repeating the literal 11.
                encoder = load(encoder_type)(
                    input_size=input_size // splice,
                    splice=splice,
                    num_units=256,
                    num_layers=5,
                    num_classes=0,  # return hidden states
                    lstm_impl=lstm_impl,
                    parameter_init=0.1)
            elif encoder_type in ['multitask_blstm', 'multitask_lstm']:
                encoder = load(encoder_type)(
                    num_units=256,
                    num_layers_main=5,
                    num_layers_sub=3,
                    num_classes_main=0,  # return hidden states
                    num_classes_sub=0,  # return hidden states
                    lstm_impl=lstm_impl,
                    parameter_init=0.1)
            elif encoder_type in ['vgg_wang', 'resnet_wang', 'cnn_zhang']:
                encoder = load(encoder_type)(
                    input_size=input_size // splice,
                    splice=splice,
                    num_classes=27,
                    parameter_init=0.1)
                # NOTE: topology is pre-defined
            else:
                raise NotImplementedError

            # Create placeholders
            inputs_pl = tf.placeholder(tf.float32,
                                       shape=[None, None, input_size],
                                       name='inputs')
            inputs_seq_len_pl = tf.placeholder(tf.int32,
                                               shape=[None],
                                               name='inputs_seq_len')
            keep_prob_input_pl = tf.placeholder(tf.float32,
                                                name='keep_prob_input')
            keep_prob_hidden_pl = tf.placeholder(tf.float32,
                                                 name='keep_prob_hidden')
            keep_prob_output_pl = tf.placeholder(tf.float32,
                                                 name='keep_prob_output')

            # operation for forward computation
            if encoder_type in ['multitask_blstm', 'multitask_lstm']:
                hidden_states_op, final_state_op, hidden_states_sub_op, final_state_sub_op = encoder(
                    inputs=inputs_pl,
                    inputs_seq_len=inputs_seq_len_pl,
                    keep_prob_input=keep_prob_input_pl,
                    keep_prob_hidden=keep_prob_hidden_pl,
                    keep_prob_output=keep_prob_output_pl)
            else:
                hidden_states_op, final_state_op = encoder(
                    inputs=inputs_pl,
                    inputs_seq_len=inputs_seq_len_pl,
                    keep_prob_input=keep_prob_input_pl,
                    keep_prob_hidden=keep_prob_hidden_pl,
                    keep_prob_output=keep_prob_output_pl)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()),
                   "{:,}".format(total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                inputs_pl: inputs,
                inputs_seq_len_pl: inputs_seq_len,
                keep_prob_input_pl: 0.9,
                keep_prob_hidden_pl: 0.9,
                keep_prob_output_pl: 1.0
            }

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Make prediction
                if encoder_type in ['multitask_blstm', 'multitask_lstm']:
                    hidden_states, final_state, hidden_states_sub, final_state_sub = sess.run(
                        [hidden_states_op, final_state_op, hidden_states_sub_op, final_state_sub_op], feed_dict=feed_dict)
                elif encoder_type in ['vgg_wang', 'resnet_wang', 'cnn_zhang']:
                    hidden_states = sess.run(
                        hidden_states_op, feed_dict=feed_dict)
                else:
                    hidden_states, final_state = sess.run(
                        [hidden_states_op, final_state_op], feed_dict=feed_dict)

                if encoder_type in ['blstm', 'bgru', 'vgg_blstm', 'multitask_blstm']:
                    self.assertEqual(
                        (batch_size, frame_num, encoder.num_units * 2), hidden_states.shape)

                    if encoder_type in ['blstm', 'vgg_blstm', 'multitask_blstm']:
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].c.shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].h.shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[1].c.shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[1].h.shape)

                        if encoder_type == 'multitask_blstm':
                            self.assertEqual(
                                (batch_size, frame_num, encoder.num_units * 2), hidden_states_sub.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[0].c.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[0].h.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[1].c.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[1].h.shape)
                    else:
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[1].shape)

                # 'multitask_lstm' was missing here, which made the nested
                # multitask_lstm checks below unreachable.
                elif encoder_type in ['lstm', 'gru', 'vgg_lstm', 'multitask_lstm']:
                    self.assertEqual(
                        (batch_size, frame_num, encoder.num_units), hidden_states.shape)

                    if encoder_type in ['lstm', 'vgg_lstm', 'multitask_lstm']:
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].c.shape)
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].h.shape)

                        if encoder_type == 'multitask_lstm':
                            self.assertEqual(
                                (batch_size, frame_num, encoder.num_units), hidden_states_sub.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[0].c.shape)
                            self.assertEqual(
                                (batch_size, encoder.num_units), final_state_sub[0].h.shape)
                    else:
                        self.assertEqual(
                            (batch_size, encoder.num_units), final_state[0].shape)

                elif encoder_type in ['vgg_wang', 'resnet_wang', 'cnn_zhang']:
                    self.assertEqual(
                        (frame_num, batch_size, encoder.num_classes), hidden_states.shape)
Пример #9
0
    def check_training(self):
        """Train the multitask BLSTM-CTC model on a tiny synthetic batch.

        Builds a TF1 graph with sparse-label placeholders for the main
        (character) and sub (phone) tasks, trains with RMSProp for up to
        400 steps and, every 10 steps, prints the loss, both label error
        rates and one decoded example per task. Training stops early once
        the character error rate has not improved for 5 consecutive
        evaluations.
        """

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 4
            inputs, labels_true_char_st, labels_true_phone_st, inputs_seq_len = generate_data(
                label_type='multitask', model='ctc', batch_size=batch_size)

            # Define placeholders
            inputs_pl = tf.placeholder(tf.float32,
                                       shape=[None, None, inputs.shape[-1]],
                                       name='inputs')
            indices_pl = tf.placeholder(tf.int64, name='indices')
            values_pl = tf.placeholder(tf.int32, name='values')
            shape_pl = tf.placeholder(tf.int64, name='shape')
            labels_pl = tf.SparseTensor(indices_pl, values_pl, shape_pl)
            indices_sub_pl = tf.placeholder(tf.int64, name='indices_sub')
            values_sub_pl = tf.placeholder(tf.int32, name='values_sub')
            shape_sub_pl = tf.placeholder(tf.int64, name='shape_sub')
            labels_sub_pl = tf.SparseTensor(indices_sub_pl, values_sub_pl,
                                            shape_sub_pl)
            inputs_seq_len_pl = tf.placeholder(tf.int64,
                                               shape=[None],
                                               name='inputs_seq_len')
            keep_prob_input_pl = tf.placeholder(tf.float32,
                                                name='keep_prob_input')
            keep_prob_hidden_pl = tf.placeholder(tf.float32,
                                                 name='keep_prob_hidden')

            # Define model graph
            num_classes_main = 26
            num_classes_sub = 61
            network = Multitask_BLSTM_CTC(batch_size=batch_size,
                                          input_size=inputs[0].shape[1],
                                          num_unit=256,
                                          num_layer_main=2,
                                          num_layer_sub=1,
                                          num_classes_main=num_classes_main,
                                          num_classes_sub=num_classes_sub,
                                          main_task_weight=0.8,
                                          parameter_init=0.1,
                                          clip_grad=5.0,
                                          clip_activation=50,
                                          dropout_ratio_input=1.0,
                                          dropout_ratio_hidden=1.0,
                                          num_proj=None,
                                          weight_decay=1e-8)

            # Add to the graph each operation
            loss_op, logits_main, logits_sub = network.compute_loss(
                inputs_pl, labels_pl, labels_sub_pl, inputs_seq_len_pl,
                keep_prob_input_pl, keep_prob_hidden_pl)
            learning_rate = 1e-3
            train_op = network.train(loss_op,
                                     optimizer='rmsprop',
                                     learning_rate_init=learning_rate,
                                     is_scheduled=False)
            decode_op_main, decode_op_sub = network.decoder(
                logits_main,
                logits_sub,
                inputs_seq_len_pl,
                decode_type='beam_search',
                beam_width=20)
            ler_op_main, ler_op_sub = network.compute_ler(
                decode_op_main, decode_op_sub, labels_pl, labels_sub_pl)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()), "{:,}".format(
                      total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                inputs_pl: inputs,
                labels_pl: labels_true_char_st,
                labels_sub_pl: labels_true_phone_st,
                inputs_seq_len_pl: inputs_seq_len,
                keep_prob_input_pl: network.dropout_ratio_input,
                keep_prob_hidden_pl: network.dropout_ratio_hidden,
                network.lr: learning_rate
            }

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 400
                start_time_global = time.time()
                start_time_step = time.time()
                ler_train_char_pre = 1
                not_improved_count = 0
                for step in range(max_steps):

                    # Compute loss
                    _, loss_train = sess.run([train_op, loss_op],
                                             feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(network.clipped_grads, feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Change to evaluation mode (keep every unit)
                        feed_dict[keep_prob_input_pl] = 1.0
                        feed_dict[keep_prob_hidden_pl] = 1.0

                        # Compute accuracy
                        ler_train_char, ler_train_phone = sess.run(
                            [ler_op_main, ler_op_sub], feed_dict=feed_dict)

                        duration_step = time.time() - start_time_step
                        print(
                            'Step %d: loss = %.3f / cer = %.4f / per = %.4f (%.3f sec)\n'
                            % (step + 1, loss_train, ler_train_char,
                               ler_train_phone, duration_step))
                        start_time_step = time.time()

                        # Visualize
                        labels_pred_char_st, labels_pred_phone_st = sess.run(
                            [decode_op_main, decode_op_sub],
                            feed_dict=feed_dict)
                        labels_true_char = sparsetensor2list(
                            labels_true_char_st, batch_size=batch_size)
                        labels_true_phone = sparsetensor2list(
                            labels_true_phone_st, batch_size=batch_size)
                        labels_pred_char = sparsetensor2list(
                            labels_pred_char_st, batch_size=batch_size)
                        labels_pred_phone = sparsetensor2list(
                            labels_pred_phone_st, batch_size=batch_size)

                        # character
                        print('Character')
                        print('  True: %s' % num2alpha(labels_true_char[0]))
                        print('  Pred: %s' % num2alpha(labels_pred_char[0]))
                        print('Phone')
                        print('  True: %s' % num2phone(labels_true_phone[0]))
                        print('  Pred: %s' % num2phone(labels_pred_phone[0]))
                        print('----------------------------------------')

                        if ler_train_char >= ler_train_char_pre:
                            not_improved_count += 1
                        else:
                            not_improved_count = 0
                        if not_improved_count >= 5:
                            print('Modle is Converged.')
                            break
                        ler_train_char_pre = ler_train_char

                        # Change back to training mode. The keep
                        # probabilities were overwritten with 1.0 above for
                        # evaluation; without restoring them here every
                        # later training step would run with the evaluation
                        # setting instead of the model's configured values.
                        feed_dict[keep_prob_input_pl] = network.dropout_ratio_input
                        feed_dict[keep_prob_hidden_pl] = network.dropout_ratio_hidden
                        network.is_training = True

                duration_global = time.time() - start_time_global
                print('Total time: %.3f sec' % (duration_global))
    def check(self, decoder_type):
        """Restore a CTC model from a checkpoint in ./ and decode one batch.

        Args:
            decoder_type (str): one of 'tf_greedy', 'tf_beam_search'
                (decoding ops built in the TF graph) or 'np_greedy',
                'np_beam_search' (GreedyDecoder / BeamSearchDecoder applied
                to the CTC posteriors fetched from the graph).

        Raises:
            ValueError: if `decoder_type` is not one of the values above,
                or if no checkpoint is found in the current directory.
        """
        valid_decoder_types = ['tf_greedy', 'tf_beam_search',
                               'np_greedy', 'np_beam_search']
        if decoder_type not in valid_decoder_types:
            # Fail fast: previously an unknown type fell through to the
            # posterior branch and crashed on an unbound `decoder`.
            raise ValueError('decoder_type must be one of %s, got %r.' %
                             (valid_decoder_types, decoder_type))

        print('==================================================')
        print('  decoder_type: %s' % decoder_type)
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 2
            num_stack = 2
            inputs, labels, inputs_seq_len = generate_data(
                label_type='character',
                model='ctc',
                batch_size=batch_size,
                num_stack=num_stack,
                splice=1)
            max_time = inputs.shape[1]

            # Define model graph
            model = CTC(encoder_type='blstm',
                        input_size=inputs[0].shape[-1],
                        splice=1,
                        num_stack=num_stack,
                        num_units=256,
                        num_layers=2,
                        num_classes=27,
                        lstm_impl='LSTMBlockCell',
                        parameter_init=0.1,
                        clip_grad_norm=5.0,
                        clip_activation=50,
                        num_proj=256,
                        weight_decay=1e-6)

            # Define placeholders
            model.create_placeholders()

            # Add to the graph each operation
            _, logits = model.compute_loss(model.inputs_pl_list[0],
                                           model.labels_pl_list[0],
                                           model.inputs_seq_len_pl_list[0],
                                           model.keep_prob_pl_list[0])
            beam_width = 20 if 'beam_search' in decoder_type else 1
            decode_op = model.decoder(logits,
                                      model.inputs_seq_len_pl_list[0],
                                      beam_width=beam_width)
            ler_op = model.compute_ler(decode_op, model.labels_pl_list[0])
            posteriors_op = model.posteriors(logits, blank_prior=1)

            # The posteriors are reshaped to `model.num_classes` columns
            # below, so valid class indices are 0 .. num_classes - 1.
            # `model.num_classes` itself is out of range. Both numpy
            # decoders therefore share the last index as the blank label
            # (the beam-search decoder already used it).
            blank_index = model.num_classes - 1
            if decoder_type == 'np_greedy':
                decoder = GreedyDecoder(blank_index=blank_index)
            elif decoder_type == 'np_beam_search':
                decoder = BeamSearchDecoder(space_index=26,
                                            blank_index=blank_index)

            # Make feed dict
            feed_dict = {
                model.inputs_pl_list[0]: inputs,
                model.labels_pl_list[0]: list2sparsetensor(labels,
                                                           padded_value=-1),
                model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                model.keep_prob_pl_list[0]: 1.0
            }

            # Create a saver for writing training checkpoints
            saver = tf.train.Saver()

            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state('./')

                # If check point exists
                if ckpt:
                    model_path = ckpt.model_checkpoint_path
                    saver.restore(sess, model_path)
                    print("Model restored: " + model_path)
                else:
                    raise ValueError('There are not any checkpoints.')

                if decoder_type in ['tf_greedy', 'tf_beam_search']:
                    # Decode
                    labels_pred_st = sess.run(decode_op, feed_dict=feed_dict)
                    labels_pred = sparsetensor2list(labels_pred_st,
                                                    batch_size=batch_size)

                    # Compute accuracy
                    cer = sess.run(ler_op, feed_dict=feed_dict)
                else:
                    # Compute CTC posteriors
                    probs = sess.run(posteriors_op, feed_dict=feed_dict)
                    probs = probs.reshape(-1, max_time, model.num_classes)

                    if decoder_type == 'np_greedy':
                        # Decode
                        labels_pred = decoder(probs=probs,
                                              seq_len=inputs_seq_len)

                    elif decoder_type == 'np_beam_search':
                        # Decode
                        labels_pred, scores = decoder(probs=probs,
                                                      seq_len=inputs_seq_len,
                                                      beam_width=beam_width)

                    # Compute accuracy
                    cer = compute_cer(str_pred=idx2alpha(labels_pred[0]),
                                      str_true=idx2alpha(labels[0]),
                                      normalize=True)

                # Visualize
                print('CER: %.3f %%' % (cer * 100))
                print('Ref: %s' % idx2alpha(labels[0]))
                print('Hyp: %s' % idx2alpha(labels_pred[0]))
Пример #11
0
    def check(self,
              encoder_type,
              decoder_type,
              bidirectional=False,
              attention_type='location',
              subsample=False,
              projection=False,
              ctc_loss_weight_sub=0,
              conv=False,
              batch_norm=False,
              residual=False,
              dense_residual=False,
              num_heads=1,
              backward_sub=False):
        """Train a HierarchicalAttentionSeq2seq model on a tiny batch (Chainer).

        Builds the model from the given configuration, trains it with Adam
        for up to 300 steps and, every 10 steps, evaluates the loss,
        greedily decodes both tasks (word-level main task, character-level
        sub task via `task_index=1`) and prints WER/CER. Training stops once
        the character error rate drops below 0.1.

        Args:
            encoder_type (str): encoder architecture; 'cnn' (or `conv=True`)
                enables the VGG-like convolution front end.
            decoder_type (str): decoder architecture.
            bidirectional (bool): use a bidirectional encoder.
            attention_type (str): attention mechanism.
            subsample: falsy to disable subsampling; otherwise forwarded as
                `subsample_type` (presumably a type name string such as
                'drop' — confirm against callers).
            projection (bool): add a 320-dim encoder projection.
            ctc_loss_weight_sub (float): CTC loss weight of the sub task;
                when non-zero, the attention sub loss weight is set to 0.
            conv (bool): enable the convolution front end.
            batch_norm (bool): enable batch normalization.
            residual (bool): residual connections in encoder and decoder.
            dense_residual (bool): dense residual connections.
            num_heads (int): attention heads for both tasks.
            backward_sub (bool): forwarded as `backward_sub` to the model.
        """

        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  bidirectional: %s' % str(bidirectional))
        print('  projection: %s' % str(projection))
        print('  decoder_type: %s' % decoder_type)
        print('  attention_type: %s' % attention_type)
        print('  subsample: %s' % str(subsample))
        print('  ctc_loss_weight_sub: %s' % str(ctc_loss_weight_sub))
        print('  conv: %s' % str(conv))
        print('  batch_norm: %s' % str(batch_norm))
        print('  residual: %s' % str(residual))
        print('  dense_residual: %s' % str(dense_residual))
        print('  backward_sub: %s' % str(backward_sub))
        print('  num_heads: %s' % str(num_heads))
        print('==================================================')

        if conv or encoder_type == 'cnn':
            # pattern 1
            # conv_channels = [32, 32]
            # conv_kernel_sizes = [[41, 11], [21, 11]]
            # conv_strides = [[2, 2], [2, 1]]
            # poolings = [[], []]

            # pattern 2 (VGG like)
            conv_channels = [64, 64]
            conv_kernel_sizes = [[3, 3], [3, 3]]
            conv_strides = [[1, 1], [1, 1]]
            poolings = [[2, 2], [2, 2]]
        else:
            conv_channels = []
            conv_kernel_sizes = []
            conv_strides = []
            poolings = []

        # Load batch data
        splice = 1
        num_stack = 1 if subsample or conv or encoder_type == 'cnn' else 2
        xs, ys, ys_sub, x_lens, y_lens, y_lens_sub = generate_data(
            label_type='word_char',
            batch_size=2,
            num_stack=num_stack,
            splice=splice,
            backend='chainer')

        num_classes = 11
        num_classes_sub = 27

        # Load model
        model = HierarchicalAttentionSeq2seq(
            input_size=xs[0].shape[-1] // splice // num_stack,  # 120
            encoder_type=encoder_type,
            encoder_bidirectional=bidirectional,
            encoder_num_units=320,
            encoder_num_proj=320 if projection else 0,
            encoder_num_layers=2,
            encoder_num_layers_sub=1,
            attention_type=attention_type,
            attention_dim=128,
            decoder_type=decoder_type,
            decoder_num_units=320,
            decoder_num_layers=1,
            decoder_num_units_sub=320,
            decoder_num_layers_sub=1,
            embedding_dim=64,
            embedding_dim_sub=32,
            dropout_input=0.1,
            dropout_encoder=0.1,
            dropout_decoder=0.1,
            dropout_embedding=0.1,
            main_loss_weight=0.8,
            sub_loss_weight=0.2 if ctc_loss_weight_sub == 0 else 0,
            num_classes=num_classes,
            num_classes_sub=num_classes_sub,
            parameter_init_distribution='uniform',
            parameter_init=0.1,
            recurrent_weight_orthogonal=False,
            init_forget_gate_bias_with_one=True,
            subsample_list=[] if not subsample else [True, False],
            # NOTE(review): when truthy, `subsample` is forwarded verbatim as
            # the subsampling type; callers presumably pass a type string.
            subsample_type='drop' if not subsample else subsample,
            bridge_layer=True,
            init_dec_state='first',
            sharpening_factor=1,
            logits_temperature=1,
            sigmoid_smoothing=False,
            ctc_loss_weight_sub=ctc_loss_weight_sub,
            attention_conv_num_channels=10,
            attention_conv_width=201,
            input_channel=3,
            num_stack=num_stack,
            splice=splice,
            conv_channels=conv_channels,
            conv_kernel_sizes=conv_kernel_sizes,
            conv_strides=conv_strides,
            poolings=poolings,
            activation='relu',
            batch_norm=batch_norm,
            scheduled_sampling_prob=0.1,
            scheduled_sampling_max_step=200,
            label_smoothing_prob=0.1,
            weight_noise_std=0,
            encoder_residual=residual,
            encoder_dense_residual=dense_residual,
            decoder_residual=residual,
            decoder_dense_residual=dense_residual,
            decoding_order='attend_generate_update',
            bottleneck_dim=256,
            bottleneck_dim_sub=256,
            backward_sub=backward_sub,
            num_heads=num_heads,
            num_heads_sub=num_heads)

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            print("%s %d" % (name, num_params))
        print("Total %.3f M parameters" % (model.total_parameters / 1000000))

        # Define optimizer
        learning_rate = 1e-3
        model.set_optimizer('adam',
                            learning_rate_init=learning_rate,
                            weight_decay=1e-6,
                            lr_schedule=False,
                            factor=0.1,
                            patience_epoch=5)

        # Define learning rate controller
        lr_controller = Controller(learning_rate_init=learning_rate,
                                   backend='chainer',
                                   decay_start_epoch=20,
                                   decay_rate=0.9,
                                   decay_patient_epoch=10,
                                   lower_better=True)

        # GPU setting
        model.set_cuda(deterministic=False, benchmark=True)

        # Train model
        max_step = 300
        start_time_step = time.time()
        for step in range(max_step):

            # Step for parameter update
            loss, loss_main, loss_sub = model(xs, ys, x_lens, y_lens, ys_sub,
                                              y_lens_sub)
            model.optimizer.target.cleargrads()
            model.cleargrads()
            loss.backward()
            loss.unchain_backward()
            model.optimizer.update()

            if (step + 1) % 10 == 0:
                # Compute loss
                loss, loss_main, loss_sub = model(xs,
                                                  ys,
                                                  x_lens,
                                                  y_lens,
                                                  ys_sub,
                                                  y_lens_sub,
                                                  is_eval=True)

                # Decode
                best_hyps, _, _ = model.decode(
                    xs,
                    x_lens,
                    beam_width=1,
                    # beam_width=2,
                    max_decode_len=30)
                best_hyps_sub, _, _ = model.decode(
                    xs,
                    x_lens,
                    beam_width=1,
                    # beam_width=2,
                    max_decode_len=60,
                    task_index=1)

                # Drop the last token and cut at the first '>' marker
                str_hyp = idx2word(best_hyps[0][:-1]).split('>')[0]
                str_ref = idx2word(ys[0])
                str_hyp_sub = idx2char(best_hyps_sub[0][:-1]).split('>')[0]
                str_ref_sub = idx2char(ys_sub[0])

                # Compute accuracy. Early in training the hypotheses can be
                # degenerate; treat any scoring failure as a 100% error
                # rate. Catch `Exception` rather than everything so that
                # Ctrl-C (KeyboardInterrupt) and SystemExit still propagate.
                try:
                    wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                               hyp=str_hyp.split('_'),
                                               normalize=True)
                    cer, _, _, _ = compute_wer(
                        ref=list(str_ref_sub.replace('_', '')),
                        hyp=list(str_hyp_sub.replace('_', '')),
                        normalize=True)
                except Exception:
                    wer = 1
                    cer = 1

                duration_step = time.time() - start_time_step
                print(
                    'Step %d: loss=%.3f(%.3f/%.3f) / wer=%.3f / cer=%.3f / lr=%.5f (%.3f sec)'
                    % (step + 1, loss, loss_main, loss_sub, wer, cer,
                       learning_rate, duration_step))
                start_time_step = time.time()

                # Visualize
                print('Ref: %s' % str_ref)
                print('Hyp (word): %s' % str_hyp)
                print('Hyp (char): %s' % str_hyp_sub)

                if cer < 0.1:
                    print('Modle is Converged.')
                    break

                # Update learning rate
                model.optimizer, learning_rate = lr_controller.decay_lr(
                    optimizer=model.optimizer,
                    learning_rate=learning_rate,
                    epoch=step,
                    value=wer)
    def check_training(self, attention_type, label_type):
        """Train the BLSTM attention model on a single synthetic utterance.

        Builds a TF1 graph, trains with RMSProp for up to 400 steps and,
        every 10 steps, greedily decodes with both the training and the
        inference decoder outputs, computes the label error rate of the
        inference hypothesis and prints it. Stops early once the error rate
        has not improved for 10 consecutive evaluations.

        Args:
            attention_type (str): attention mechanism passed to the model.
            label_type (str): 'character' (26 + 2 classes) or anything
                else, which is treated as phones (61 + 2 classes). The extra
                two classes are the SOS/EOS indices.
        """

        print('----- attention_type: ' + attention_type + ', label_type: ' +
              label_type + ' -----')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 1
            inputs, labels, inputs_seq_len, labels_seq_len = generate_data(
                label_type=label_type,
                model='attention',
                batch_size=batch_size)

            # Define placeholders
            inputs_pl = tf.placeholder(
                tf.float32,
                shape=[batch_size, None, inputs.shape[-1]],
                name='inputs')

            # `[batch_size, max_time]`
            labels_pl = tf.placeholder(tf.int32,
                                       shape=[None, None],
                                       name='labels')

            # These are prepared for computing LER
            indices_true_pl = tf.placeholder(tf.int64, name='indices')
            values_true_pl = tf.placeholder(tf.int32, name='values')
            shape_true_pl = tf.placeholder(tf.int64, name='shape')
            labels_st_true_pl = tf.SparseTensor(indices_true_pl,
                                                values_true_pl, shape_true_pl)
            indices_pred_pl = tf.placeholder(tf.int64, name='indices')
            values_pred_pl = tf.placeholder(tf.int32, name='values')
            shape_pred_pl = tf.placeholder(tf.int64, name='shape')
            labels_st_pred_pl = tf.SparseTensor(indices_pred_pl,
                                                values_pred_pl, shape_pred_pl)
            inputs_seq_len_pl = tf.placeholder(tf.int32,
                                               shape=[None],
                                               name='inputs_seq_len')
            labels_seq_len_pl = tf.placeholder(tf.int32,
                                               shape=[None],
                                               name='labels_seq_len')
            keep_prob_input_pl = tf.placeholder(tf.float32,
                                                name='keep_prob_input')
            keep_prob_hidden_pl = tf.placeholder(tf.float32,
                                                 name='keep_prob_hidden')

            # Define model graph
            num_classes = 26 + 2 if label_type == 'character' else 61 + 2
            # model = load(model_type=model_type)
            network = BLSTMAttetion(
                batch_size=batch_size,
                input_size=inputs[0].shape[1],
                encoder_num_unit=256,
                encoder_num_layer=2,
                attention_dim=128,
                attention_type=attention_type,
                decoder_num_unit=256,
                decoder_num_layer=1,
                embedding_dim=20,
                num_classes=num_classes,
                sos_index=num_classes - 2,
                eos_index=num_classes - 1,
                max_decode_length=50,
                # attention_smoothing=True,
                attention_weights_tempareture=0.5,
                logits_tempareture=1.0,
                parameter_init=0.1,
                clip_grad=5.0,
                clip_activation_encoder=50,
                clip_activation_decoder=50,
                dropout_ratio_input=1.0,
                dropout_ratio_hidden=1.0,
                weight_decay=0,
                beam_width=0,
                time_major=False)

            # Add to the graph each operation
            loss_op, logits, decoder_outputs_train, decoder_outputs_infer = network.compute_loss(
                inputs_pl, labels_pl, inputs_seq_len_pl, labels_seq_len_pl,
                keep_prob_input_pl, keep_prob_hidden_pl)
            learning_rate = 1e-3
            train_op = network.train(loss_op,
                                     optimizer='rmsprop',
                                     learning_rate_init=learning_rate,
                                     is_scheduled=False)
            decode_op_train, decode_op_infer = network.decoder(
                decoder_outputs_train,
                decoder_outputs_infer,
                decode_type='greedy',
                beam_width=1)
            ler_op = network.compute_ler(labels_st_true_pl, labels_st_pred_pl)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()), "{:,}".format(
                      total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                inputs_pl: inputs,
                labels_pl: labels,
                inputs_seq_len_pl: inputs_seq_len,
                labels_seq_len_pl: labels_seq_len,
                keep_prob_input_pl: network.dropout_ratio_input,
                keep_prob_hidden_pl: network.dropout_ratio_hidden,
                network.lr: learning_rate
            }

            with tf.Session() as sess:

                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 400
                start_time_global = time.time()
                start_time_step = time.time()
                ler_train_pre = 1
                not_improved_count = 0
                for step in range(max_steps):

                    # Compute loss
                    _, loss_train = sess.run([train_op, loss_op],
                                             feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(network.clipped_grads,
                    #                  feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Change to evaluation mode (keep every unit)
                        feed_dict[keep_prob_input_pl] = 1.0
                        feed_dict[keep_prob_hidden_pl] = 1.0

                        # Predict class ids
                        predicted_ids_train, predicted_ids_infer = sess.run(
                            [decode_op_train, decode_op_infer],
                            feed_dict=feed_dict)

                        # Compute accuracy
                        try:
                            feed_dict_ler = {
                                labels_st_true_pl:
                                list2sparsetensor(labels, padded_value=0),
                                labels_st_pred_pl:
                                list2sparsetensor(predicted_ids_infer,
                                                  padded_value=0)
                            }
                            ler_train = sess.run(ler_op,
                                                 feed_dict=feed_dict_ler)
                        except ValueError:
                            ler_train = 1

                        duration_step = time.time() - start_time_step
                        print('Step %d: loss = %.3f / ler = %.4f (%.3f sec)' %
                              (step + 1, loss_train, ler_train, duration_step))
                        start_time_step = time.time()

                        # Visualize
                        if label_type == 'character':
                            print('True            : %s' %
                                  num2alpha(labels[0]))
                            print('Pred (Training) : <%s' %
                                  num2alpha(predicted_ids_train[0]))
                            print('Pred (Inference): <%s' %
                                  num2alpha(predicted_ids_infer[0]))
                        else:
                            print('True            : %s' %
                                  num2phone(labels[0]))
                            print('Pred (Training) : < %s' %
                                  num2phone(predicted_ids_train[0]))
                            print('Pred (Inference): < %s' %
                                  num2phone(predicted_ids_infer[0]))

                        if ler_train >= ler_train_pre:
                            not_improved_count += 1
                        else:
                            not_improved_count = 0
                        if not_improved_count >= 10:
                            print('Model is Converged.')
                            break
                        ler_train_pre = ler_train

                        # Change back to training mode: the keep
                        # probabilities were forced to 1.0 for evaluation
                        # above and must be restored to the configured
                        # values before the next training step.
                        feed_dict[keep_prob_input_pl] = network.dropout_ratio_input
                        feed_dict[keep_prob_hidden_pl] = network.dropout_ratio_hidden

                duration_global = time.time() - start_time_global
                print('Total time: %.3f sec' % (duration_global))
# Example #13
# 0
    def check(self, usage_dec_sub='all', att_reg_weight=1,
              main_loss_weight=0.5, ctc_loss_weight_sub=0,
              dec_attend_temperature=1,
              dec_sigmoid_smoothing=False,
              backward_sub=False, num_heads=1, second_pass=False,
              relax_context_vec_dec=False):
        """Train a small NestedAttentionSeq2seq model on toy data.

        Builds a word (main) / char (sub) nested attention model and trains
        it with Adam for up to 300 steps. Every 10 steps it evaluates the
        loss, decodes the batch with beam width 1, and prints WER/CER and the
        first hypothesis. Training stops early once the sub-task error rate
        (``cer``) drops below 0.1.

        Args:
            usage_dec_sub (str): forwarded to the model as ``usage_dec_sub``.
            att_reg_weight (float): forwarded as ``att_reg_weight``.
            main_loss_weight (float): only printed; the model is built with a
                fixed ``main_loss_weight=0.8`` (see NOTE below).
            ctc_loss_weight_sub (float): CTC loss weight of the sub task. When
                non-zero, the attention sub loss weight is set to 0.
            dec_attend_temperature (float): forwarded as
                ``dec_attend_temperature``.
            dec_sigmoid_smoothing (bool): forwarded as
                ``dec_sigmoid_smoothing``.
            backward_sub (bool): forwarded as ``backward_sub``.
            num_heads (int): number of heads for the encoder, sub and
                decoder attentions.
            second_pass (bool): if True, the sub task predicts words
                (11 classes) and the model is called without sub labels.
            relax_context_vec_dec (bool): forwarded as
                ``relax_context_vec_dec``.
        """

        print('==================================================')
        print('  usage_dec_sub: %s' % usage_dec_sub)
        print('  att_reg_weight: %s' % str(att_reg_weight))
        print('  main_loss_weight: %s' % str(main_loss_weight))
        print('  ctc_loss_weight_sub: %s' % str(ctc_loss_weight_sub))
        print('  dec_attend_temperature: %s' % str(dec_attend_temperature))
        print('  dec_sigmoid_smoothing: %s' % str(dec_sigmoid_smoothing))
        print('  backward_sub: %s' % str(backward_sub))
        print('  num_heads: %s' % str(num_heads))
        print('  second_pass: %s' % str(second_pass))
        print('  relax_context_vec_dec: %s' % str(relax_context_vec_dec))
        print('==================================================')

        # Load batch data
        splice = 1
        num_stack = 1
        xs, ys, ys_sub, x_lens, y_lens, y_lens_sub = generate_data(
            label_type='word_char',
            batch_size=2,
            num_stack=num_stack,
            splice=splice)

        # Load model
        # NOTE(review): ``main_loss_weight`` is printed above but the model is
        # built with a fixed 0.8; confirm whether the argument should be used.
        model = NestedAttentionSeq2seq(
            input_size=xs.shape[-1] // splice // num_stack,  # 120
            encoder_type='lstm',
            encoder_bidirectional=True,
            encoder_num_units=256,
            encoder_num_proj=0,
            encoder_num_layers=2,
            encoder_num_layers_sub=2,
            attention_type='location',
            attention_dim=128,
            decoder_type='lstm',
            decoder_num_units=256,
            decoder_num_layers=1,
            decoder_num_units_sub=256,
            decoder_num_layers_sub=1,
            embedding_dim=64,
            embedding_dim_sub=32,
            dropout_input=0.1,
            dropout_encoder=0.1,
            dropout_decoder=0.1,
            dropout_embedding=0.1,
            main_loss_weight=0.8,
            # The attention sub loss is disabled when the CTC sub loss is used
            sub_loss_weight=0.2 if ctc_loss_weight_sub == 0 else 0,
            num_classes=11,
            num_classes_sub=27 if not second_pass else 11,
            parameter_init_distribution='uniform',
            parameter_init=0.1,
            recurrent_weight_orthogonal=False,
            init_forget_gate_bias_with_one=True,
            subsample_list=[True, False],
            subsample_type='drop',
            init_dec_state='first',
            sharpening_factor=1,
            logits_temperature=1,
            sigmoid_smoothing=False,
            ctc_loss_weight_sub=ctc_loss_weight_sub,
            attention_conv_num_channels=10,
            attention_conv_width=201,
            num_stack=num_stack,
            splice=1,
            conv_channels=[],
            conv_kernel_sizes=[],
            conv_strides=[],
            poolings=[],
            batch_norm=False,
            scheduled_sampling_prob=0.1,
            scheduled_sampling_max_step=200,
            label_smoothing_prob=0.1,
            weight_noise_std=0,
            encoder_residual=False,
            encoder_dense_residual=False,
            decoder_residual=False,
            decoder_dense_residual=False,
            decoding_order='attend_generate_update',
            # decoding_order='attend_update_generate',
            # decoding_order='conditional',
            bottleneck_dim=256,
            bottleneck_dim_sub=256,
            backward_sub=backward_sub,
            num_heads=num_heads,
            num_heads_sub=num_heads,
            num_heads_dec=num_heads,
            usage_dec_sub=usage_dec_sub,
            att_reg_weight=att_reg_weight,
            dec_attend_temperature=dec_attend_temperature,
            # Previously this was wired to ``dec_attend_temperature`` by
            # mistake, so the ``dec_sigmoid_smoothing`` argument was ignored.
            dec_sigmoid_smoothing=dec_sigmoid_smoothing,
            relax_context_vec_dec=relax_context_vec_dec,
            dec_attention_type='location')

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            print("%s %d" % (name, num_params))
        print("Total %.3f M parameters" % (model.total_parameters / 1000000))

        # Define optimizer
        learning_rate = 1e-3
        model.set_optimizer('adam',
                            learning_rate_init=learning_rate,
                            weight_decay=1e-6,
                            lr_schedule=False,
                            factor=0.1,
                            patience_epoch=5)

        # Define learning rate controller
        lr_controller = Controller(learning_rate_init=learning_rate,
                                   backend='pytorch',
                                   decay_start_epoch=20,
                                   decay_rate=0.9,
                                   decay_patient_epoch=10,
                                   lower_better=True)

        # GPU setting
        model.set_cuda(deterministic=False, benchmark=True)

        # Train model
        max_step = 300
        start_time_step = time.time()
        for step in range(max_step):

            # Step for parameter update
            model.optimizer.zero_grad()
            if second_pass:
                # In the second pass the sub task shares the main (word) labels
                loss = model(xs, ys, x_lens, y_lens)
            else:
                loss, loss_main, loss_sub = model(
                    xs, ys, x_lens, y_lens, ys_sub, y_lens_sub)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            model.optimizer.step()

            if (step + 1) % 10 == 0:
                # Compute loss
                if second_pass:
                    loss = model(xs, ys, x_lens, y_lens, is_eval=True)
                else:
                    loss, loss_main, loss_sub = model(
                        xs, ys, x_lens, y_lens, ys_sub, y_lens_sub, is_eval=True)

                best_hyps, _, best_hyps_sub, _, perm_idx = model.decode(
                    xs, x_lens, beam_width=1,
                    max_decode_len=30,
                    max_decode_len_sub=60)

                # ``[:-1]`` drops the last token of each hypothesis
                str_hyp = idx2word(best_hyps[0][:-1])
                str_ref = idx2word(ys[0])
                if second_pass:
                    str_hyp_sub = idx2word(best_hyps_sub[0][:-1])
                    str_ref_sub = idx2word(ys[0])
                else:
                    str_hyp_sub = idx2char(best_hyps_sub[0][:-1])
                    str_ref_sub = idx2char(ys_sub[0])

                # Compute accuracy (fall back to 100% error if scoring fails)
                try:
                    wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                               hyp=str_hyp.split('_'),
                                               normalize=True)
                    if second_pass:
                        cer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                                   hyp=str_hyp_sub.split('_'),
                                                   normalize=True)
                    else:
                        cer, _, _, _ = compute_wer(
                            ref=list(str_ref_sub.replace('_', '')),
                            hyp=list(str_hyp_sub.replace('_', '')),
                            normalize=True)
                except Exception:
                    wer = 1
                    cer = 1

                duration_step = time.time() - start_time_step
                if second_pass:
                    print('Step %d: loss=%.3f / wer=%.3f / cer=%.3f / lr=%.5f (%.3f sec)' %
                          (step + 1, loss, wer, cer, learning_rate, duration_step))
                else:
                    print('Step %d: loss=%.3f(%.3f/%.3f) / wer=%.3f / cer=%.3f / lr=%.5f (%.3f sec)' %
                          (step + 1, loss, loss_main, loss_sub,
                           wer, cer, learning_rate, duration_step))

                start_time_step = time.time()

                # Visualize
                print('Ref: %s' % str_ref)
                print('Hyp (word): %s' % str_hyp)
                print('Hyp (char): %s' % str_hyp_sub)

                if cer < 0.1:
                    print('Modle is Converged.')
                    break

                # Update learning rate
                model.optimizer, learning_rate = lr_controller.decay_lr(
                    optimizer=model.optimizer,
                    learning_rate=learning_rate,
                    epoch=step,
                    value=wer)
# Example #14
# 0
    def check_encode(self, model_type, label_type):
        """Run a TensorFlow encoder once and check the shapes of its outputs.

        The encoder returned by ``load(model_type)`` is built with 5 layers of
        256 units and fed one batch of 4 utterances with dropout disabled.
        The shapes of its outputs, final state(s) and attention values are
        then compared with the expected sizes. For a ``model_type`` other
        than the four known encoders, the encoder is still run but no shapes
        are checked.
        """
        print('----- ' + model_type + ', ' + label_type + ' -----')
        tf.reset_default_graph()
        with tf.Graph().as_default():
            # A single small batch of toy data
            batch_size = 4
            inputs, _, inputs_seq_len, _ = generate_data(
                label_type=label_type,
                model='attention',
                batch_size=batch_size)

            first_utterance = inputs[0]
            frame_num = first_utterance.shape[0]
            input_size = first_utterance.shape[1]

            # Graph inputs
            inputs_pl = tf.placeholder(
                tf.float32, shape=[None, None, input_size], name='inputs')
            inputs_seq_len_pl = tf.placeholder(
                tf.int64, shape=[None], name='inputs_seq_len')
            keep_prob_input_pl = tf.placeholder(
                tf.float32, name='keep_prob_input')
            keep_prob_hidden_pl = tf.placeholder(
                tf.float32, name='keep_prob_hidden')

            encoder_cls = load(model_type)
            encoder = encoder_cls(num_unit=256,
                                  num_layer=5,
                                  parameter_init=0.1,
                                  clip_activation=5.0,
                                  num_proj=None)
            encoder_outputs_op = encoder(inputs=inputs_pl,
                                         inputs_seq_len=inputs_seq_len_pl,
                                         keep_prob_input=keep_prob_input_pl,
                                         keep_prob_hidden=keep_prob_hidden_pl)

            # Dropout is switched off (keep probability 1.0)
            feed_dict = {
                inputs_pl: inputs,
                inputs_seq_len_pl: inputs_seq_len,
                keep_prob_input_pl: 1.0,
                keep_prob_hidden_pl: 1.0,
            }
            init_op = tf.global_variables_initializer()

            # model_type -> (bidirectional?, final states expose an LSTM ``c``?)
            layouts = {
                'blstm_encoder': (True, True),
                'lstm_encoder': (False, True),
                'bgru_encoder': (True, False),
                'gru_encoder': (False, False),
            }

            with tf.Session() as sess:
                sess.run(init_op)
                encoder_outputs = sess.run(encoder_outputs_op,
                                           feed_dict=feed_dict)

                if model_type not in layouts:
                    return
                bidirectional, is_lstm = layouts[model_type]

                # Pick up the final layer
                outputs = encoder_outputs.outputs
                if bidirectional:
                    final_state_fw, final_state_bw = encoder_outputs.final_state
                    final_states = (final_state_fw, final_state_bw)
                else:
                    final_states = (encoder_outputs.final_state[-1],)
                attention_values = encoder_outputs.attention_values
                attention_values_length = encoder_outputs.attention_values_length

                output_dim = encoder.num_unit * (2 if bidirectional else 1)
                self.assertEqual((batch_size, frame_num, output_dim),
                                 outputs.shape)
                for state in final_states:
                    state_shape = state.c.shape if is_lstm else state.shape
                    self.assertEqual((batch_size, encoder.num_unit), state_shape)
                self.assertEqual((batch_size, frame_num, output_dim),
                                 attention_values.shape)
                self.assertEqual(frame_num, attention_values_length[0])
    def check(self, decoder_type):
        """Decode a toy batch with a restored CTC model and print the CER.

        Builds a 2-layer BLSTM CTC graph and restores the latest checkpoint
        found in the current directory. It then decodes the batch with the
        decoder selected by ``decoder_type`` and prints the CER, reference and
        hypothesis of the first utterance.

        Args:
            decoder_type (str): ``'tf_greedy'`` / ``'tf_beam_search'`` run the
                model's TF decode op. ``'np_greedy'`` / ``'np_beam_search'``
                run ``GreedyDecoder`` / ``BeamSearchDecoder`` on the CTC
                posteriors. Types containing ``'beam_search'`` use a beam
                width of 20; all others use 1.

        Raises:
            ValueError: if there is no checkpoint in the current directory.
        """

        print('==================================================')
        print('  decoder_type: %s' % decoder_type)
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 2
            num_stack = 2
            inputs, labels, inputs_seq_len = generate_data(
                label_type='character',
                model='ctc',
                batch_size=batch_size,
                num_stack=num_stack,
                splice=1)
            max_time = inputs.shape[1]

            # Define model graph
            model = CTC(encoder_type='blstm',
                        input_size=inputs[0].shape[-1],
                        splice=1,
                        num_stack=num_stack,
                        num_units=256,
                        num_layers=2,
                        num_classes=27,
                        lstm_impl='LSTMBlockCell',
                        parameter_init=0.1,
                        clip_grad_norm=5.0,
                        clip_activation=50,
                        num_proj=256,
                        weight_decay=1e-6)

            # Define placeholders
            model.create_placeholders()

            # Add to the graph each operation
            _, logits = model.compute_loss(
                model.inputs_pl_list[0],
                model.labels_pl_list[0],
                model.inputs_seq_len_pl_list[0],
                model.keep_prob_pl_list[0])
            beam_width = 20 if 'beam_search' in decoder_type else 1
            decode_op = model.decoder(logits,
                                      model.inputs_seq_len_pl_list[0],
                                      beam_width=beam_width)
            ler_op = model.compute_ler(decode_op, model.labels_pl_list[0])
            posteriors_op = model.posteriors(logits, blank_prior=1)

            # The posteriors are reshaped to ``model.num_classes`` classes
            # below, so the last valid class index is ``model.num_classes - 1``.
            # Both decoders use that index as the blank (assumes blank is the
            # last class). ``model.num_classes`` itself is out of range and
            # would never match, so blanks were never removed by the greedy
            # decoder.
            blank_index = model.num_classes - 1
            if decoder_type == 'np_greedy':
                decoder = GreedyDecoder(blank_index=blank_index)
            elif decoder_type == 'np_beam_search':
                decoder = BeamSearchDecoder(space_index=26,
                                            blank_index=blank_index)

            # Make feed dict
            feed_dict = {
                model.inputs_pl_list[0]: inputs,
                model.labels_pl_list[0]: list2sparsetensor(labels,
                                                           padded_value=-1),
                model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                model.keep_prob_pl_list[0]: 1.0
            }

            # Create a saver for writing training checkpoints
            saver = tf.train.Saver()

            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state('./')

                # If check point exists
                if ckpt:
                    model_path = ckpt.model_checkpoint_path
                    saver.restore(sess, model_path)
                    print("Model restored: " + model_path)
                else:
                    raise ValueError('There are not any checkpoints.')

                if decoder_type in ['tf_greedy', 'tf_beam_search']:
                    # Decode
                    labels_pred_st = sess.run(decode_op, feed_dict=feed_dict)
                    labels_pred = sparsetensor2list(
                        labels_pred_st, batch_size=batch_size)

                    # Compute accuracy
                    cer = sess.run(ler_op, feed_dict=feed_dict)
                else:
                    # Compute CTC posteriors
                    probs = sess.run(posteriors_op, feed_dict=feed_dict)
                    probs = probs.reshape(-1, max_time, model.num_classes)

                    if decoder_type == 'np_greedy':
                        # Decode
                        labels_pred = decoder(probs=probs,
                                              seq_len=inputs_seq_len)

                    elif decoder_type == 'np_beam_search':
                        # Decode (scores are not used here)
                        labels_pred, _ = decoder(probs=probs,
                                                 seq_len=inputs_seq_len,
                                                 beam_width=beam_width)

                    # Compute accuracy
                    cer = compute_cer(str_pred=idx2alpha(labels_pred[0]),
                                      str_true=idx2alpha(labels[0]),
                                      normalize=True)

                # Visualize
                print('CER: %.3f %%' % (cer * 100))
                print('Ref: %s' % idx2alpha(labels[0]))
                print('Hyp: %s' % idx2alpha(labels_pred[0]))
    def check(self,
              encoder_type,
              bidirectional=False,
              batch_first=True,
              subsample_type='concat',
              conv=False,
              merge_bidirectional=False,
              projection=False,
              residual=False,
              dense_residual=False):
        """Run a 6-layer RNN encoder on toy data and check its output shapes.

        Both ``outputs_sub`` and ``outputs`` are compared with sizes derived
        from the input length, the optional VGG-like conv front end and the
        subsampling pattern ``[False, True, True, False, False, False]``
        (``num_layers_sub=4``).

        Raises:
            NotImplementedError: if ``encoder_type`` is not one of
                ``'lstm'``, ``'gru'`` or ``'rnn'``.
        """
        print('==================================================')
        for key, value in (('encoder_type', encoder_type),
                           ('bidirectional', bidirectional),
                           ('batch_first', batch_first),
                           ('subsample_type', subsample_type),
                           ('conv', conv),
                           ('merge_bidirectional', merge_bidirectional),
                           ('projection', projection),
                           ('residual', residual),
                           ('dense_residual', dense_residual)):
            print('  %s: %s' % (key, value))
        print('==================================================')

        if conv:
            # VGG-like convolutional front end
            conv_channels = [64, 64]
            conv_kernel_sizes = [[3, 3], [3, 3]]
            conv_strides = [[1, 1], [1, 1]]
            poolings = [[2, 2], [2, 2]]
        else:
            conv_channels, conv_kernel_sizes, conv_strides, poolings = \
                [], [], [], []

        # One batch of toy data, wrapped as torch tensors
        batch_size = 4
        splice = 1
        num_stack = 1
        xs, _, x_lens, _ = generate_data(batch_size=batch_size,
                                         num_stack=num_stack,
                                         splice=splice)
        xs, x_lens = torch.from_numpy(xs), torch.from_numpy(x_lens)

        encoder_cls = load(encoder_type=encoder_type)
        if encoder_type not in ('lstm', 'gru', 'rnn'):
            raise NotImplementedError
        encoder = encoder_cls(
            input_size=xs.size(-1) // splice // num_stack,  # 120
            rnn_type=encoder_type,
            bidirectional=bidirectional,
            num_units=256,
            num_proj=256 if projection else 0,
            num_layers=6,
            num_layers_sub=4,
            dropout_input=0.2,
            dropout_hidden=0.2,
            subsample_list=[False, True, True, False, False, False],
            subsample_type=subsample_type,
            batch_first=batch_first,
            merge_bidirectional=merge_bidirectional,
            splice=splice,
            num_stack=num_stack,
            conv_channels=conv_channels,
            conv_kernel_sizes=conv_kernel_sizes,
            conv_strides=conv_strides,
            poolings=poolings,
            batch_norm=True,
            residual=residual,
            dense_residual=dense_residual)

        # Expected time lengths after the conv front end and subsampling
        max_time = xs.size(1)
        if conv:
            max_time = encoder.conv.get_conv_out_size(max_time, 1)
        sub_factor = 2 ** sum(encoder.subsample_list[:encoder.num_layers_sub])
        full_factor = 2 ** sum(encoder.subsample_list)
        max_time_sub = max_time // sub_factor
        max_time = max_time // full_factor
        # NOTE(review): the lengths are already floor-divided above, so for
        # integer lengths this rounding does not change them; confirm that
        # floor matches the encoder's 'drop' subsampling.
        rounding = {'drop': math.ceil, 'concat': int}.get(subsample_type)
        if rounding is not None:
            max_time_sub = rounding(max_time_sub)
            max_time = rounding(max_time)

        outputs, _, outputs_sub, _, _ = encoder(xs, x_lens)

        print('----- outputs -----')
        print(xs.size())
        print(outputs_sub.size())
        print(outputs.size())

        num_directions = 2 if bidirectional and not merge_bidirectional else 1
        feature_dim = encoder.num_units * num_directions
        if batch_first:
            expected_sub = (batch_size, max_time_sub, feature_dim)
            expected = (batch_size, max_time, feature_dim)
        else:
            expected_sub = (max_time_sub, batch_size, feature_dim)
            expected = (max_time, batch_size, feature_dim)
        self.assertEqual(expected_sub, outputs_sub.size())
        self.assertEqual(expected, outputs.size())
    def check(self,
              encoder_type,
              decoder_type,
              bidirectional=False,
              attention_type='location',
              label_type='char',
              subsample=False,
              projection=False,
              init_dec_state='first',
              ctc_loss_weight=0,
              conv=False,
              batch_norm=False,
              residual=False,
              dense_residual=False,
              decoding_order='bahdanau_attention',
              backward_loss_weight=0,
              num_heads=1,
              beam_width=1):
        """Train a small AttentionSeq2seq model on toy data.

        Builds the model from the given options and trains it with Adam for
        up to 300 steps. Every 10 steps it evaluates the loss, decodes the
        batch with ``beam_width`` and prints the label error rate (``ler``)
        for the first utterance. If ``ctc_loss_weight >= 0.1`` it also prints
        the CTC hypothesis. Training stops early once ``ler`` drops below 0.1.

        Args:
            label_type (str): ``'char'`` (27 classes) or ``'word'``
                (11 classes).
            subsample: falsy to disable subsampling. Otherwise it is used as
                the model's ``subsample_type`` and the encoder gets 2 layers.
            beam_width (int): beam width used for decoding.
            The remaining arguments are forwarded to ``AttentionSeq2seq``.

        Raises:
            ValueError: if ``label_type`` is neither ``'char'`` nor ``'word'``.
        """

        print('==================================================')
        print('  label_type: %s' % label_type)
        print('  encoder_type: %s' % encoder_type)
        print('  bidirectional: %s' % str(bidirectional))
        print('  projection: %s' % str(projection))
        print('  decoder_type: %s' % decoder_type)
        print('  init_dec_state: %s' % init_dec_state)
        print('  attention_type: %s' % attention_type)
        print('  subsample: %s' % str(subsample))
        print('  ctc_loss_weight: %s' % str(ctc_loss_weight))
        print('  conv: %s' % str(conv))
        print('  batch_norm: %s' % str(batch_norm))
        print('  residual: %s' % str(residual))
        print('  dense_residual: %s' % str(dense_residual))
        print('  decoding_order: %s' % decoding_order)
        print('  backward_loss_weight: %s' % str(backward_loss_weight))
        print('  num_heads: %s' % str(num_heads))
        print('  beam_width: %s' % str(beam_width))
        print('==================================================')

        if conv or encoder_type == 'cnn':
            # pattern 2 (VGG like)
            conv_channels = [64, 64]
            conv_kernel_sizes = [[3, 3], [3, 3]]
            conv_strides = [[1, 1], [1, 1]]
            poolings = [[2, 2], [2, 2]]
        else:
            conv_channels = []
            conv_kernel_sizes = []
            conv_strides = []
            poolings = []

        # Load batch data (frames are only stacked when nothing else
        # shortens the input)
        splice = 1
        num_stack = 1 if subsample or conv or encoder_type == 'cnn' else 3
        xs, ys, x_lens, y_lens = generate_data(label_type=label_type,
                                               batch_size=2,
                                               num_stack=num_stack,
                                               splice=splice)

        if label_type == 'char':
            num_classes = 27
            map_fn = idx2char
        elif label_type == 'word':
            num_classes = 11
            map_fn = idx2word
        else:
            # Fail early instead of with a NameError on ``num_classes``
            raise ValueError('label_type must be "char" or "word", got %r'
                             % (label_type,))

        # Load model
        model = AttentionSeq2seq(
            input_size=xs.shape[-1] // splice // num_stack,  # 120
            encoder_type=encoder_type,
            encoder_bidirectional=bidirectional,
            encoder_num_units=256,
            encoder_num_proj=256 if projection else 0,
            encoder_num_layers=1 if not subsample else 2,
            attention_type=attention_type,
            attention_dim=128,
            decoder_type=decoder_type,
            decoder_num_units=256,
            decoder_num_layers=1,
            embedding_dim=32,
            dropout_input=0.1,
            dropout_encoder=0.1,
            dropout_decoder=0.1,
            dropout_embedding=0.1,
            num_classes=num_classes,
            parameter_init_distribution='uniform',
            parameter_init=0.1,
            recurrent_weight_orthogonal=False,
            init_forget_gate_bias_with_one=True,
            subsample_list=[] if not subsample else [True, False],
            subsample_type='concat' if not subsample else subsample,
            bridge_layer=True,
            init_dec_state=init_dec_state,
            sharpening_factor=1,
            logits_temperature=1,
            sigmoid_smoothing=False,
            coverage_weight=0,
            ctc_loss_weight=ctc_loss_weight,
            attention_conv_num_channels=10,
            attention_conv_width=201,
            num_stack=num_stack,
            splice=splice,
            input_channel=3,
            conv_channels=conv_channels,
            conv_kernel_sizes=conv_kernel_sizes,
            conv_strides=conv_strides,
            poolings=poolings,
            activation='relu',
            batch_norm=batch_norm,
            scheduled_sampling_prob=0.1,
            scheduled_sampling_max_step=200,
            label_smoothing_prob=0.1,
            weight_noise_std=1e-9,
            encoder_residual=residual,
            encoder_dense_residual=dense_residual,
            decoder_residual=residual,
            decoder_dense_residual=dense_residual,
            decoding_order=decoding_order,
            bottleneck_dim=256,
            backward_loss_weight=backward_loss_weight,
            num_heads=num_heads)

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            print("%s %d" % (name, num_params))
        print("Total %.3f M parameters" % (model.total_parameters / 1000000))

        # Define optimizer
        learning_rate = 1e-3
        model.set_optimizer('adam',
                            learning_rate_init=learning_rate,
                            weight_decay=1e-8,
                            lr_schedule=False,
                            factor=0.1,
                            patience_epoch=5)

        # Define learning rate controller
        lr_controller = Controller(learning_rate_init=learning_rate,
                                   backend='pytorch',
                                   decay_start_epoch=20,
                                   decay_rate=0.9,
                                   decay_patient_epoch=10,
                                   lower_better=True)

        # GPU setting
        model.set_cuda(deterministic=False, benchmark=True)

        # Prefer the in-place clipping API (PyTorch >= 0.4). Fall back to the
        # deprecated name on older versions.
        clip_grad_norm = (getattr(torch.nn.utils, 'clip_grad_norm_', None) or
                          torch.nn.utils.clip_grad_norm)

        # Train model
        max_step = 300
        start_time_step = time.time()
        for step in range(max_step):

            # Step for parameter update
            model.optimizer.zero_grad()
            loss = model(xs, ys, x_lens, y_lens)
            loss.backward()
            clip_grad_norm(model.parameters(), 5)
            model.optimizer.step()

            # Turn on weight noise injection once the training loss is small.
            # ``loss.data[0]`` raises IndexError on 0-dim tensors in newer
            # PyTorch, so use ``item()`` whenever it exists.
            loss_value = loss.item() if hasattr(loss, 'item') else loss.data[0]
            if loss_value < 50:
                model.weight_noise_injection = True

            if (step + 1) % 10 == 0:
                # Compute loss
                loss = model(xs, ys, x_lens, y_lens, is_eval=True)

                # Decode
                best_hyps, _, perm_idx = model.decode(xs,
                                                      x_lens,
                                                      beam_width,
                                                      max_decode_len=60)

                # ``[:-1]`` drops the last token of the hypothesis
                str_ref = map_fn(ys[0])
                str_hyp = map_fn(best_hyps[0][:-1])

                # Compute accuracy (fall back to 100% error if scoring fails)
                try:
                    if label_type == 'char':
                        ler, _, _, _ = compute_wer(
                            ref=list(str_ref.replace('_', '')),
                            hyp=list(str_hyp.replace('_', '')),
                            normalize=True)
                    elif label_type == 'word':
                        ler, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                                   hyp=str_hyp.split('_'),
                                                   normalize=True)
                except Exception:
                    ler = 1

                duration_step = time.time() - start_time_step
                print('Step %d: loss=%.3f / ler=%.3f / lr=%.5f (%.3f sec)' %
                      (step + 1, loss, ler, learning_rate, duration_step))
                start_time_step = time.time()

                # Visualize
                print('Ref: %s' % str_ref)
                print('Hyp: %s' % str_hyp)

                # Decode by the CTC decoder
                if model.ctc_loss_weight >= 0.1:
                    best_hyps_ctc, perm_idx = model.decode_ctc(xs,
                                                               x_lens,
                                                               beam_width=1)
                    str_pred_ctc = map_fn(best_hyps_ctc[0])
                    print('Hyp (CTC): %s' % str_pred_ctc)

                if ler < 0.1:
                    print('Modle is Converged.')
                    break

                # Update learning rate
                model.optimizer, learning_rate = lr_controller.decay_lr(
                    optimizer=model.optimizer,
                    learning_rate=learning_rate,
                    epoch=step,
                    value=ler)
    def check(self,
              encoder_type,
              bidirectional=False,
              subsample=False,
              projection=False,
              conv=False,
              batch_norm=False,
              activation='relu',
              encoder_residual=False,
              encoder_dense_residual=False,
              label_smoothing=False):
        """Train a HierarchicalCTC model on a toy batch and report WER/CER.

        The main task (11 word classes) and the sub task (27 character
        classes) are trained jointly with Adam for up to 300 steps. Every 10
        steps the model is evaluated, both tasks are decoded with beam width
        2, WER/CER are printed, and training stops early once CER < 0.1.

        Args:
            encoder_type (str): encoder type passed to HierarchicalCTC;
                'cnn' also selects the convolutional configuration below.
            bidirectional (bool): forwarded as `encoder_bidirectional`.
            subsample (bool): if True, pass subsample_list=[True, False] and
                do not stack frames (num_stack=1).
            projection (bool): use a 256-dim encoder projection.
            conv (bool): use the VGG-like convolutional configuration.
            batch_norm (bool): forwarded as `batch_norm`.
            activation (str): NOTE(review): accepted for signature parity
                with the other `check` helpers but NOT forwarded to
                HierarchicalCTC in this method.
            encoder_residual (bool): forwarded as `encoder_residual`.
            encoder_dense_residual (bool): forwarded as
                `encoder_dense_residual`.
            label_smoothing (bool): use label smoothing probability 0.1.
        """

        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  bidirectional: %s' % str(bidirectional))
        print('  projection: %s' % str(projection))
        print('  subsample: %s' % str(subsample))
        print('  conv: %s' % str(conv))
        print('  batch_norm: %s' % str(batch_norm))
        print('  encoder_residual: %s' % str(encoder_residual))
        print('  encoder_dense_residual: %s' % str(encoder_dense_residual))
        print('  label_smoothing: %s' % str(label_smoothing))
        print('==================================================')

        if conv or encoder_type == 'cnn':
            # pattern 1
            # conv_channels = [32, 32]
            # conv_kernel_sizes = [[41, 11], [21, 11]]
            # conv_strides = [[2, 2], [2, 1]]
            # poolings = [[], []]

            # pattern 2 (VGG like)
            conv_channels = [64, 64]
            conv_kernel_sizes = [[3, 3], [3, 3]]
            conv_strides = [[1, 1], [1, 1]]
            poolings = [[2, 2], [2, 2]]

            fc_list = [786, 786]
        else:
            conv_channels = []
            conv_kernel_sizes = []
            conv_strides = []
            poolings = []
            fc_list = []

        # Load batch data
        num_stack = 1 if subsample or conv or encoder_type == 'cnn' else 2
        splice = 1
        xs, ys, ys_sub, x_lens, y_lens, y_lens_sub = generate_data(
            label_type='word_char',
            batch_size=2,
            num_stack=num_stack,
            splice=splice)

        num_classes = 11
        num_classes_sub = 27

        # Load model
        model = HierarchicalCTC(
            input_size=xs.shape[-1] // splice // num_stack,  # 120
            encoder_type=encoder_type,
            encoder_bidirectional=bidirectional,
            encoder_num_units=256,
            encoder_num_proj=256 if projection else 0,
            encoder_num_layers=2,
            encoder_num_layers_sub=1,
            fc_list=fc_list,
            fc_list_sub=fc_list,
            dropout_input=0.1,
            dropout_encoder=0.1,
            main_loss_weight=0.8,
            sub_loss_weight=0.2,
            num_classes=num_classes,
            num_classes_sub=num_classes_sub,
            parameter_init_distribution='uniform',
            parameter_init=0.1,
            recurrent_weight_orthogonal=False,
            init_forget_gate_bias_with_one=True,
            subsample_list=[] if not subsample else [True, False],
            num_stack=num_stack,
            splice=splice,
            input_channel=3,
            conv_channels=conv_channels,
            conv_kernel_sizes=conv_kernel_sizes,
            conv_strides=conv_strides,
            poolings=poolings,
            batch_norm=batch_norm,
            label_smoothing_prob=0.1 if label_smoothing else 0,
            weight_noise_std=0,
            encoder_residual=encoder_residual,
            encoder_dense_residual=encoder_dense_residual)

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            print("%s %d" % (name, num_params))
        print("Total %.3f M parameters" % (model.total_parameters / 1000000))

        # Define optimizer
        learning_rate = 1e-3
        model.set_optimizer('adam',
                            learning_rate_init=learning_rate,
                            weight_decay=1e-6,
                            lr_schedule=False,
                            factor=0.1,
                            patience_epoch=5)

        # Define learning rate controller
        lr_controller = Controller(learning_rate_init=learning_rate,
                                   backend='pytorch',
                                   decay_start_epoch=20,
                                   decay_rate=0.9,
                                   decay_patient_epoch=10,
                                   lower_better=True)

        # GPU setting
        model.set_cuda(deterministic=False, benchmark=True)

        # Train model
        max_step = 300
        start_time_step = time.time()
        for step in range(max_step):

            # Step for parameter update
            model.optimizer.zero_grad()
            loss, loss_main, loss_sub = model(xs, ys, x_lens, y_lens, ys_sub,
                                              y_lens_sub)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            model.optimizer.step()

            if (step + 1) % 10 == 0:
                # Compute loss
                loss, loss_main, loss_sub = model(xs,
                                                  ys,
                                                  x_lens,
                                                  y_lens,
                                                  ys_sub,
                                                  y_lens_sub,
                                                  is_eval=True)

                # Decode the main (word) and sub (char) tasks
                best_hyps, _, _ = model.decode(xs,
                                               x_lens,
                                               beam_width=2,
                                               task_index=0)
                best_hyps_sub, _, _ = model.decode(xs,
                                                   x_lens,
                                                   beam_width=2,
                                                   task_index=1)

                str_ref = idx2word(ys[0, :y_lens[0]])
                str_hyp = idx2word(best_hyps[0])
                str_ref_sub = idx2char(ys_sub[0, :y_lens_sub[0]])
                str_hyp_sub = idx2char(best_hyps_sub[0])

                # Compute accuracy. Any failure in the metric (e.g. on a
                # degenerate hypothesis) counts as a 100% error rate, but
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                try:
                    wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                               hyp=str_hyp.split('_'),
                                               normalize=True)
                    cer, _, _, _ = compute_wer(
                        ref=list(str_ref_sub.replace('_', '')),
                        hyp=list(str_hyp_sub.replace('_', '')),
                        normalize=True)
                except Exception:
                    wer = 1
                    cer = 1

                duration_step = time.time() - start_time_step
                print(
                    'Step %d: loss=%.3f(%.3f/%.3f) / wer=%.3f / cer=%.3f / lr=%.5f (%.3f sec)'
                    % (step + 1, loss, loss_main, loss_sub, wer, cer,
                       learning_rate, duration_step))
                start_time_step = time.time()

                # Visualize
                print('Ref: %s' % str_ref)
                print('Hyp (word): %s' % str_hyp)
                print('Hyp (char): %s' % str_hyp_sub)

                # Convergence is judged on the sub (char) task ...
                if cer < 0.1:
                    print('Model is Converged.')
                    break

                # ... while the learning rate schedule follows the main (word) task.
                model.optimizer, learning_rate = lr_controller.decay_lr(
                    optimizer=model.optimizer,
                    learning_rate=learning_rate,
                    epoch=step,
                    value=wer)
    def check(self,
              encoder_type,
              bidirectional=False,
              conv=False,
              merge_bidirectional=False,
              projection=False,
              residual=False,
              dense_residual=False):
        """Build a chainer encoder on a generated batch and assert its output shape.

        The expected shape is (batch_size, number of output frames,
        encoder.num_units * number of directions). When `conv` is set, the
        number of output frames is taken from
        `encoder.conv.get_conv_out_size`.

        Args:
            encoder_type (str): name passed to `load` and used as `rnn_type`.
            bidirectional (bool): build a bidirectional encoder.
            conv (bool): prepend the VGG-like convolutional front-end.
            merge_bidirectional (bool): forwarded to the encoder; when True
                the expected output width uses a single direction.
            projection (bool): use a 256-dim projection layer.
            residual (bool): forwarded as `residual`.
            dense_residual (bool): forwarded as `dense_residual`.
        """

        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  bidirectional: %s' % str(bidirectional))
        print('  conv: %s' % str(conv))
        print('  merge_bidirectional: %s' % str(merge_bidirectional))
        print('  projection: %s' % str(projection))
        print('  residual: %s' % str(residual))
        print('  dense_residual: %s' % str(dense_residual))
        print('==================================================')

        # Convolutional front-end: the VGG-like pattern (two 3x3 convs, each
        # followed by 2x2 pooling) when enabled, nothing otherwise.
        # An alternative, currently unused pattern:
        #   channels [32, 32], kernels [[41, 11], [21, 11]],
        #   strides [[2, 2], [2, 1]], no pooling.
        if not conv:
            conv_channels, conv_kernel_sizes, conv_strides, poolings = (
                [], [], [], [])
        else:
            conv_channels = [64, 64]
            conv_kernel_sizes = [[3, 3], [3, 3]]
            conv_strides = [[1, 1], [1, 1]]
            poolings = [[2, 2], [2, 2]]

        # Toy batch without frame stacking or splicing
        batch_size, splice, num_stack = 4, 1, 1
        xs, _, x_lens, _ = generate_data(batch_size=batch_size,
                                         num_stack=num_stack,
                                         splice=splice,
                                         backend='chainer')
        xs = [chainer.Variable(frames, requires_grad=False) for frames in xs]

        encoder_cls = load(encoder_type=encoder_type)
        encoder = encoder_cls(
            input_size=xs[0].shape[-1] // splice // num_stack,  # 120
            rnn_type=encoder_type,
            bidirectional=bidirectional,
            num_units=256,
            num_proj=256 if projection else 0,
            num_layers=5,
            dropout_input=0.2,
            dropout_hidden=0.2,
            subsample_list=[],
            merge_bidirectional=merge_bidirectional,
            splice=splice,
            num_stack=num_stack,
            conv_channels=conv_channels,
            conv_kernel_sizes=conv_kernel_sizes,
            conv_strides=conv_strides,
            poolings=poolings,
            batch_norm=True,
            residual=residual,
            dense_residual=dense_residual)

        # The conv front-end (if any) changes the number of output frames.
        input_frames = xs[0].shape[0]
        expected_frames = (encoder.conv.get_conv_out_size(input_frames, 1)
                           if conv else input_frames)

        outputs, _ = encoder(xs, x_lens)
        observed_shape = (len(outputs), outputs[0].shape[0],
                          outputs[0].shape[1])

        print('----- outputs -----')
        print(observed_shape)

        # Directions are concatenated only when bidirectional and not merged.
        num_directions = 1 if merge_bidirectional or not bidirectional else 2
        expected_shape = (batch_size, expected_frames,
                          encoder.num_units * num_directions)
        self.assertEqual(expected_shape, observed_shape)
    def check(self, encoder_type, bidirectional=False, label_type='char',
              subsample=False, projection=False,
              conv=False, batch_norm=False, activation='relu',
              encoder_residual=False, encoder_dense_residual=False,
              label_smoothing=False):
        """Train a chainer CTC model on a toy batch until it converges.

        Trains with Adam for up to 300 steps. Every 10 steps the model is
        evaluated, decoded with beam width 1 and the label error rate is
        printed. Training stops early once the error rate drops below 0.05.

        Args:
            encoder_type (str): encoder type passed to CTC; 'cnn' also
                selects the convolutional configuration below.
            bidirectional (bool): forwarded as `encoder_bidirectional`.
            label_type (str): 'char' (27 classes, CER) or 'word'
                (11 classes, WER).
            subsample (bool): use 2 encoder layers with
                subsample_list=[True, True] and no frame stacking.
            projection (bool): use a 256-dim encoder projection.
            conv (bool): use the VGG-like convolutional configuration.
            batch_norm (bool): forwarded as `batch_norm`.
            activation (str): forwarded as `activation`.
            encoder_residual (bool): forwarded as `encoder_residual`.
            encoder_dense_residual (bool): forwarded as
                `encoder_dense_residual`.
            label_smoothing (bool): use label smoothing probability 0.1.

        Raises:
            ValueError: if `label_type` is neither 'char' nor 'word'.
        """

        print('==================================================')
        print('  label_type: %s' % label_type)
        print('  encoder_type: %s' % encoder_type)
        print('  bidirectional: %s' % str(bidirectional))
        print('  projection: %s' % str(projection))
        print('  subsample: %s' % str(subsample))
        print('  conv: %s' % str(conv))
        print('  batch_norm: %s' % str(batch_norm))
        print('  activation: %s' % activation)
        print('  encoder_residual: %s' % str(encoder_residual))
        print('  encoder_dense_residual: %s' % str(encoder_dense_residual))
        print('  label_smoothing: %s' % str(label_smoothing))
        print('==================================================')

        # Resolve the label vocabulary first so an unsupported label_type
        # fails immediately instead of as a NameError much later.
        if label_type == 'char':
            num_classes = 27
            map_fn = idx2char
        elif label_type == 'word':
            num_classes = 11
            map_fn = idx2word
        else:
            raise ValueError(
                "label_type must be 'char' or 'word', got %r" % (label_type,))

        if conv or encoder_type == 'cnn':
            # pattern 1
            # conv_channels = [32, 32]
            # conv_kernel_sizes = [[41, 11], [21, 11]]
            # conv_strides = [[2, 2], [2, 1]]
            # poolings = [[], []]

            # pattern 2 (VGG like)
            conv_channels = [64, 64]
            conv_kernel_sizes = [[3, 3], [3, 3]]
            conv_strides = [[1, 1], [1, 1]]
            poolings = [[2, 2], [2, 2]]

            fc_list = [786, 786]
        else:
            conv_channels = []
            conv_kernel_sizes = []
            conv_strides = []
            poolings = []
            fc_list = []

        # Load batch data
        splice = 1
        num_stack = 1 if subsample or conv or encoder_type == 'cnn' else 2
        xs, ys, x_lens, y_lens = generate_data(
            label_type=label_type,
            batch_size=2,
            num_stack=num_stack,
            splice=splice,
            backend='chainer')

        # Load model
        model = CTC(
            input_size=xs[0].shape[-1] // splice // num_stack,  # 120
            encoder_type=encoder_type,
            encoder_bidirectional=bidirectional,
            encoder_num_units=256,
            encoder_num_proj=256 if projection else 0,
            encoder_num_layers=1 if not subsample else 2,
            fc_list=fc_list,
            dropout_input=0.1,
            dropout_encoder=0.1,
            num_classes=num_classes,
            parameter_init_distribution='uniform',
            parameter_init=0.1,
            recurrent_weight_orthogonal=False,
            init_forget_gate_bias_with_one=True,
            subsample_list=[] if not subsample else [True] * 2,
            num_stack=num_stack,
            splice=splice,
            input_channel=3,
            conv_channels=conv_channels,
            conv_kernel_sizes=conv_kernel_sizes,
            conv_strides=conv_strides,
            poolings=poolings,
            activation=activation,
            batch_norm=batch_norm,
            label_smoothing_prob=0.1 if label_smoothing else 0,
            weight_noise_std=0,
            encoder_residual=encoder_residual,
            encoder_dense_residual=encoder_dense_residual)

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            print("%s %d" % (name, num_params))
        print("Total %.3f M parameters" % (model.total_parameters / 1000000))

        # Define optimizer
        learning_rate = 1e-3
        model.set_optimizer(
            'adam',
            # 'adadelta',
            learning_rate_init=learning_rate,
            weight_decay=1e-6,
            clip_grad_norm=5,
            lr_schedule=None,
            factor=None,
            patience_epoch=None)

        # Define learning rate controller
        lr_controller = Controller(learning_rate_init=learning_rate,
                                   backend='chainer',
                                   decay_start_epoch=20,
                                   decay_rate=0.9,
                                   decay_patient_epoch=10,
                                   lower_better=True)

        # GPU setting
        model.set_cuda(deterministic=False, benchmark=True)

        # Train model
        max_step = 300
        start_time_step = time.time()
        for step in range(max_step):

            # Step for parameter update
            loss = model(xs, ys, x_lens, y_lens)
            model.optimizer.target.cleargrads()
            model.cleargrads()
            loss.backward()
            # Cut the computational graph so it can be freed after the update
            loss.unchain_backward()
            model.optimizer.update()

            if (step + 1) % 10 == 0:
                # Compute loss
                loss = model(xs, ys, x_lens, y_lens, is_eval=True)

                # Decode
                best_hyps, _, _ = model.decode(xs, x_lens, beam_width=1)
                # TODO: fix beam search

                str_ref = map_fn(ys[0, :y_lens[0]])
                str_hyp = map_fn(best_hyps[0])

                # Compute accuracy. Any failure in the metric counts as a
                # 100% error rate, but KeyboardInterrupt/SystemExit are no
                # longer swallowed.
                try:
                    if label_type == 'char':
                        ler, _, _, _ = compute_wer(
                            ref=list(str_ref.replace('_', '')),
                            hyp=list(str_hyp.replace('_', '')),
                            normalize=True)
                    else:
                        ler, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                                   hyp=str_hyp.split('_'),
                                                   normalize=True)
                except Exception:
                    ler = 1

                duration_step = time.time() - start_time_step
                print('Step %d: loss=%.3f / ler=%.3f / lr=%.5f (%.3f sec)' %
                      (step + 1, loss, ler, learning_rate, duration_step))
                start_time_step = time.time()

                # Visualize
                print('Ref: %s' % str_ref)
                print('Hyp: %s' % str_hyp)

                if ler < 0.05:
                    print('Model is Converged.')
                    break

                # Update learning rate
                model.optimizer, learning_rate = lr_controller.decay_lr(
                    optimizer=model.optimizer,
                    learning_rate=learning_rate,
                    epoch=step,
                    value=ler)
Пример #21
0
    def check(self, encoder_type, attention_type, label_type='character'):
        """Train a TensorFlow AttentionSeq2Seq model on a toy batch.

        Builds the graph and trains with Adam for up to 1000 steps. Every 10
        steps, dropout is disabled and the model is decoded in both training
        and inference mode. The label error rate of the inference output is
        printed, and dropout is then restored for the following training
        steps. Training stops early once the error rate drops below 0.1.

        Args:
            encoder_type (str): encoder type passed to AttentionSeq2Seq.
            attention_type (str): attention type passed to AttentionSeq2Seq.
            label_type (str): 'character' (27 classes, shown with idx2alpha)
                or anything else, treated as 61 phones (shown with an
                Idx2phone mapping loaded from ./phone61.txt).
        """

        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  attention_type: %s' % attention_type)
        print('  label_type: %s' % label_type)
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 4
            inputs, labels, inputs_seq_len, labels_seq_len = generate_data(
                label_type=label_type,
                model='attention',
                batch_size=batch_size)

            # Define model graph
            num_classes = 27 if label_type == 'character' else 61
            model = AttentionSeq2Seq(input_size=inputs[0].shape[1],
                                     encoder_type=encoder_type,
                                     encoder_num_units=256,
                                     encoder_num_layers=2,
                                     encoder_num_proj=None,
                                     attention_type=attention_type,
                                     attention_dim=128,
                                     decoder_type='lstm',
                                     decoder_num_units=256,
                                     decoder_num_layers=1,
                                     embedding_dim=64,
                                     num_classes=num_classes,
                                     sos_index=num_classes,
                                     eos_index=num_classes + 1,
                                     max_decode_length=100,
                                     use_peephole=True,
                                     splice=1,
                                     parameter_init=0.1,
                                     clip_grad_norm=5.0,
                                     clip_activation_encoder=50,
                                     clip_activation_decoder=50,
                                     weight_decay=1e-8,
                                     time_major=True,
                                     sharpening_factor=1.0,
                                     logits_temperature=1.0)

            # Define placeholders
            model.create_placeholders()
            learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

            # Add to the graph each operation
            loss_op, logits, decoder_outputs_train, decoder_outputs_infer = model.compute_loss(
                model.inputs_pl_list[0],
                model.labels_pl_list[0],
                model.inputs_seq_len_pl_list[0],
                model.labels_seq_len_pl_list[0],
                model.keep_prob_encoder_pl_list[0],
                model.keep_prob_decoder_pl_list[0],
                model.keep_prob_embedding_pl_list[0])
            train_op = model.train(loss_op,
                                   optimizer='adam',
                                   learning_rate=learning_rate_pl)
            decode_op_train, decode_op_infer = model.decode(
                decoder_outputs_train, decoder_outputs_infer)
            ler_op = model.compute_ler(model.labels_st_true_pl,
                                       model.labels_st_pred_pl)

            # Define learning rate controller
            learning_rate = 1e-3
            lr_controller = Controller(learning_rate_init=learning_rate,
                                       decay_start_epoch=20,
                                       decay_rate=0.9,
                                       decay_patient_epoch=10,
                                       lower_better=True)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()),
                   "{:,}".format(total_parameters / 1000000)))

            # Keep probabilities used during training. Evaluation temporarily
            # overrides them with 1.0 and restores them afterwards.
            keep_prob_encoder_train = 0.8
            keep_prob_decoder_train = 1.0
            keep_prob_embedding_train = 1.0

            # Make feed dict
            feed_dict = {
                model.inputs_pl_list[0]: inputs,
                model.labels_pl_list[0]: labels,
                model.inputs_seq_len_pl_list[0]: inputs_seq_len,
                model.labels_seq_len_pl_list[0]: labels_seq_len,
                model.keep_prob_encoder_pl_list[0]: keep_prob_encoder_train,
                model.keep_prob_decoder_pl_list[0]: keep_prob_decoder_train,
                model.keep_prob_embedding_pl_list[0]: keep_prob_embedding_train,
                learning_rate_pl: learning_rate
            }

            # The phone map (and its file) is only needed for phone labels.
            idx2phone = None
            if label_type != 'character':
                idx2phone = Idx2phone(map_file_path='./phone61.txt')

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Wrapper for tfdbg
                # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

                # Train model
                max_steps = 1000
                start_time_step = time.time()
                for step in range(max_steps):

                    # Compute loss
                    _, loss_train = sess.run(
                        [train_op, loss_op], feed_dict=feed_dict)

                    # Gradient check
                    # grads = sess.run(model.clipped_grads,
                    #                  feed_dict=feed_dict)
                    # for grad in grads:
                    #     print(np.max(grad))

                    if (step + 1) % 10 == 0:
                        # Change to evaluation mode (disable dropout)
                        feed_dict[model.keep_prob_encoder_pl_list[0]] = 1.0
                        feed_dict[model.keep_prob_decoder_pl_list[0]] = 1.0
                        feed_dict[model.keep_prob_embedding_pl_list[0]] = 1.0

                        # Predict class ids
                        predicted_ids_train, predicted_ids_infer = sess.run(
                            [decode_op_train, decode_op_infer],
                            feed_dict=feed_dict)

                        # Change back to training mode; otherwise dropout
                        # would stay disabled for every later training step.
                        feed_dict[model.keep_prob_encoder_pl_list[0]] = \
                            keep_prob_encoder_train
                        feed_dict[model.keep_prob_decoder_pl_list[0]] = \
                            keep_prob_decoder_train
                        feed_dict[model.keep_prob_embedding_pl_list[0]] = \
                            keep_prob_embedding_train

                        # Compute accuracy
                        try:
                            feed_dict_ler = {
                                model.labels_st_true_pl: list2sparsetensor(
                                    labels, padded_value=model.eos_index),
                                model.labels_st_pred_pl: list2sparsetensor(
                                    predicted_ids_infer, padded_value=model.eos_index)
                            }
                            ler_train = sess.run(
                                ler_op, feed_dict=feed_dict_ler)
                        except IndexError:
                            ler_train = 1

                        duration_step = time.time() - start_time_step
                        print('Step %d: loss = %.3f / ler = %.3f (%.3f sec) / lr = %.5f' %
                              (step + 1, loss_train, ler_train, duration_step, learning_rate))
                        start_time_step = time.time()

                        # Visualize
                        if label_type == 'character':
                            print('True            : %s' %
                                  idx2alpha(labels[0]))
                            print('Pred (Training) : <%s' %
                                  idx2alpha(predicted_ids_train[0]))
                            print('Pred (Inference): <%s' %
                                  idx2alpha(predicted_ids_infer[0]))
                        else:
                            print('True            : %s' %
                                  idx2phone(labels[0]))
                            print('Pred (Training) : < %s' %
                                  idx2phone(predicted_ids_train[0]))
                            print('Pred (Inference): < %s' %
                                  idx2phone(predicted_ids_infer[0]))

                        if ler_train < 0.1:
                            print('Model is Converged.')
                            break

                        # Update learning rate
                        learning_rate = lr_controller.decay_lr(
                            learning_rate=learning_rate,
                            epoch=step,
                            value=ler_train)
                        feed_dict[learning_rate_pl] = learning_rate
Пример #22
0
    def check(self, encoder_type, lstm_impl=None, time_major=False):

        print('==================================================')
        print('  encoder_type: %s' % encoder_type)
        print('  lstm_impl: %s' % lstm_impl)
        print('  time_major: %s' % time_major)
        print('==================================================')

        tf.reset_default_graph()
        with tf.Graph().as_default():
            # Load batch data
            batch_size = 4
            splice = 5 if encoder_type in [
                'vgg_blstm', 'vgg_lstm', 'vgg_wang', 'resnet_wang',
                'cldnn_wang', 'cnn_zhang'
            ] else 1
            num_stack = 2
            inputs, _, inputs_seq_len = generate_data(label_type='character',
                                                      model='ctc',
                                                      batch_size=batch_size,
                                                      num_stack=num_stack,
                                                      splice=splice)
            frame_num, input_size = inputs[0].shape

            # Define model graph
            if encoder_type in ['blstm', 'lstm']:
                encoder = load(encoder_type)(num_units=256,
                                             num_proj=None,
                                             num_layers=5,
                                             lstm_impl=lstm_impl,
                                             use_peephole=True,
                                             parameter_init=0.1,
                                             clip_activation=5,
                                             time_major=time_major)
            elif encoder_type in ['bgru', 'gru']:
                encoder = load(encoder_type)(num_units=256,
                                             num_layers=5,
                                             parameter_init=0.1,
                                             time_major=time_major)
            elif encoder_type in ['vgg_blstm', 'vgg_lstm', 'cldnn_wang']:
                encoder = load(encoder_type)(input_size=input_size // splice //
                                             num_stack,
                                             splice=splice,
                                             num_stack=num_stack,
                                             num_units=256,
                                             num_proj=None,
                                             num_layers=5,
                                             lstm_impl=lstm_impl,
                                             use_peephole=True,
                                             parameter_init=0.1,
                                             clip_activation=5,
                                             time_major=time_major)
            elif encoder_type in ['multitask_blstm', 'multitask_lstm']:
                encoder = load(encoder_type)(num_units=256,
                                             num_proj=None,
                                             num_layers_main=5,
                                             num_layers_sub=3,
                                             lstm_impl=lstm_impl,
                                             use_peephole=True,
                                             parameter_init=0.1,
                                             clip_activation=5,
                                             time_major=time_major)
            elif encoder_type in ['vgg_wang', 'resnet_wang', 'cnn_zhang']:
                encoder = load(encoder_type)(input_size=input_size // splice //
                                             num_stack,
                                             splice=splice,
                                             num_stack=num_stack,
                                             parameter_init=0.1,
                                             time_major=time_major)
                # NOTE: topology is pre-defined
            else:
                raise NotImplementedError

            # Create placeholders
            inputs_pl = tf.placeholder(tf.float32,
                                       shape=[None, None, input_size],
                                       name='inputs')
            inputs_seq_len_pl = tf.placeholder(tf.int32,
                                               shape=[None],
                                               name='inputs_seq_len')
            keep_prob_pl = tf.placeholder(tf.float32, name='keep_prob')

            # operation for forward computation
            if encoder_type in ['multitask_blstm', 'multitask_lstm']:
                hidden_states_op, final_state_op, hidden_states_sub_op, final_state_sub_op = encoder(
                    inputs=inputs_pl,
                    inputs_seq_len=inputs_seq_len_pl,
                    keep_prob=keep_prob_pl,
                    is_training=True)
            else:
                hidden_states_op, final_state_op = encoder(
                    inputs=inputs_pl,
                    inputs_seq_len=inputs_seq_len_pl,
                    keep_prob=keep_prob_pl,
                    is_training=True)

            # Add the variable initializer operation
            init_op = tf.global_variables_initializer()

            # Count total parameters
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()), "{:,}".format(
                      total_parameters / 1000000)))

            # Make feed dict
            feed_dict = {
                inputs_pl: inputs,
                inputs_seq_len_pl: inputs_seq_len,
                keep_prob_pl: 0.9
            }

            with tf.Session() as sess:
                # Initialize parameters
                sess.run(init_op)

                # Make prediction
                if encoder_type in ['multitask_blstm', 'multitask_lstm']:
                    encoder_outputs, final_state, hidden_states_sub, final_state_sub = sess.run(
                        [
                            hidden_states_op, final_state_op,
                            hidden_states_sub_op, final_state_sub_op
                        ],
                        feed_dict=feed_dict)
                elif encoder_type in ['vgg_wang', 'resnet_wang', 'cnn_zhang']:
                    encoder_outputs = sess.run(hidden_states_op,
                                               feed_dict=feed_dict)
                else:
                    encoder_outputs, final_state = sess.run(
                        [hidden_states_op, final_state_op],
                        feed_dict=feed_dict)

                # Convert always to batch-major
                if time_major:
                    encoder_outputs = encoder_outputs.transpose(1, 0, 2)

                if encoder_type in [
                        'blstm', 'bgru', 'vgg_blstm', 'multitask_blstm',
                        'cldnn_wang'
                ]:
                    if encoder_type != 'cldnn_wang':
                        self.assertEqual(
                            (batch_size, frame_num, encoder.num_units * 2),
                            encoder_outputs.shape)

                    if encoder_type != 'bgru':
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[0].c.shape)
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[0].h.shape)
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[1].c.shape)
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[1].h.shape)

                        if encoder_type == 'multitask_blstm':
                            self.assertEqual(
                                (batch_size, frame_num, encoder.num_units * 2),
                                hidden_states_sub.shape)
                            self.assertEqual((batch_size, encoder.num_units),
                                             final_state_sub[0].c.shape)
                            self.assertEqual((batch_size, encoder.num_units),
                                             final_state_sub[0].h.shape)
                            self.assertEqual((batch_size, encoder.num_units),
                                             final_state_sub[1].c.shape)
                            self.assertEqual((batch_size, encoder.num_units),
                                             final_state_sub[1].h.shape)
                    else:
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[0].shape)
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[1].shape)

                elif encoder_type in [
                        'lstm', 'gru', 'vgg_lstm', 'multitask_lstm'
                ]:
                    self.assertEqual(
                        (batch_size, frame_num, encoder.num_units),
                        encoder_outputs.shape)

                    if encoder_type != 'gru':
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[0].c.shape)
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[0].h.shape)

                        if encoder_type == 'multitask_lstm':
                            self.assertEqual(
                                (batch_size, frame_num, encoder.num_units),
                                hidden_states_sub.shape)
                            self.assertEqual((batch_size, encoder.num_units),
                                             final_state_sub[0].c.shape)
                            self.assertEqual((batch_size, encoder.num_units),
                                             final_state_sub[0].h.shape)
                    else:
                        self.assertEqual((batch_size, encoder.num_units),
                                         final_state[0].shape)

                elif encoder_type in ['vgg_wang', 'resnet_wang', 'cnn_zhang']:
                    self.assertEqual(3, len(encoder_outputs.shape))
                    self.assertEqual((batch_size, frame_num),
                                     encoder_outputs.shape[:2])