示例#1
0
    def insight(self,
                data,
                spklist,
                batch_type="softmax",
                output_embeddings=False,
                aux_data=None):
        """Debug helper: evaluate the model on `data` and stop in a pdb session.

        The graph variables are initialized, `self.load()` is called, the valid
        loss is accumulated over every batch the loader yields, the accuracy of
        the last processed batch is printed, and finally a pdb breakpoint is hit
        so the state can be inspected interactively.

        Args:
            data: The data directory handed to the data loader.
            spklist: The speaker list file handed to the data loader.
            batch_type: `softmax` or `end2end`. Selects the loader that builds the
                        batches used to compute the loss.
            output_embeddings: Set True to also collect the embeddings and labels of
                               the whole set (loaded in order, without shuffling).
            aux_data: Not used by this method.

        :return: loss, embeddings and labels (None if output_embeddings is False
                 or if the loader yields no batch).
        """
        self.sess.run(tf.global_variables_initializer())
        self.sess.run(tf.local_variables_initializer())
        assert batch_type == "softmax" or batch_type == "end2end", "The batch_type can only be softmax or end2end"

        embeddings_val = None
        labels_val = None

        self.load()

        if output_embeddings:
            # If we want to output embeddings, the features should be loaded in order
            data_loader = KaldiDataSeqQueue(
                data,
                spklist,
                num_parallel=2,
                max_qsize=10,
                batch_size=self.params.num_speakers_per_batch *
                self.params.num_segments_per_speaker,
                min_len=self.params.min_segment_len,
                max_len=self.params.max_segment_len,
                shuffle=False)
            data_loader.start()

            tf.logging.info("Generate valid embeddings.")
            # Collect the per-batch arrays and concatenate them once at the end;
            # concatenating inside the loop copies the whole buffer on every batch.
            emb_chunks = []
            label_chunks = []
            try:
                while True:
                    try:
                        features, labels = data_loader.fetch()
                    except DataOutOfRange:
                        break
                    valid_emb_val, valid_labels_val = self.sess.run(
                        [self.embeddings, self.valid_labels],
                        feed_dict={
                            self.valid_features: features,
                            self.valid_labels: labels
                        })
                    emb_chunks.append(valid_emb_val)
                    label_chunks.append(valid_labels_val)
            finally:
                # Always stop the loader, even if sess.run fails.
                data_loader.stop()

            if emb_chunks:
                embeddings_val = np.concatenate(emb_chunks, axis=0)
                labels_val = np.concatenate(label_chunks, axis=0)

        if batch_type == "softmax":
            data_loader = KaldiDataSeqQueue(
                data,
                spklist,
                num_parallel=2,
                max_qsize=10,
                batch_size=self.params.num_speakers_per_batch *
                self.params.num_segments_per_speaker * 10,
                min_len=self.params.min_segment_len,
                max_len=self.params.max_segment_len,
                shuffle=True)
        elif batch_type == "end2end":
            # The num_valid_speakers_per_batch and num_valid_segments_per_speaker are only required when
            # End2End loss is used. Since we switch the loss function to softmax generalized e2e loss
            # when the e2e loss is used.
            assert "num_valid_speakers_per_batch" in self.params.dict and "num_valid_segments_per_speaker" in self.params.dict, \
                "Valid parameters should be set if E2E loss is selected"
            data_loader = KaldiDataRandomQueue(
                data,
                spklist,
                num_parallel=2,
                max_qsize=10,
                num_speakers=self.params.num_valid_speakers_per_batch,
                num_segments=self.params.num_valid_segments_per_speaker,
                min_len=self.params.min_segment_len,
                max_len=self.params.max_segment_len,
                shuffle=True)
        else:
            # Only reachable when assertions are disabled (python -O).
            raise ValueError("The batch_type can only be softmax or end2end")

        # Logits and labels of the last batch that was actually processed.
        last_logits = None
        last_labels = None

        data_loader.start()
        try:
            while True:
                try:
                    features, labels = data_loader.fetch()
                except DataOutOfRange:
                    break
                _, endpoints_val = self.sess.run(
                    [self.valid_ops["valid_loss_op"], self.endpoints],
                    feed_dict={
                        self.valid_features: features,
                        self.valid_labels: labels
                    })
                last_logits = endpoints_val['logits']
                last_labels = labels
        finally:
            data_loader.stop()

        loss = self.sess.run(self.valid_ops["valid_loss"])
        tf.logging.info(
            "Shorter segments are used to test the valid loss (%d-%d)" %
            (self.params.min_segment_len, self.params.max_segment_len))
        tf.logging.info("Loss: %f" % loss)

        if last_logits is None:
            tf.logging.info(
                "[Warning] No batch was fetched. The accuracy is not computed.")
        else:
            # Accuracy of the last batch only. The comparison already yields booleans,
            # so no float dtype is requested from np.equal.
            predictions = np.argmax(last_logits, axis=1)
            acc = np.sum(np.equal(predictions, last_labels)) / float(
                last_labels.shape[0])
            print("Acc: %f" % acc)

        # Deliberate breakpoint: this method exists to inspect the network interactively.
        import pdb
        pdb.set_trace()
        return loss, embeddings_val, labels_val
示例#2
0
    embeddings_val = None
    labels_val = None
    data_loader = KaldiDataSeqQueue(args.data_dir,
                                    args.data_spklist,
                                    num_parallel=1,
                                    max_qsize=10,
                                    batch_size=params.num_speakers_per_batch *
                                    params.num_segments_per_speaker,
                                    min_len=params.min_segment_len,
                                    max_len=params.max_segment_len,
                                    shuffle=False)
    data_loader.start()
    while True:
        try:
            features, labels = data_loader.fetch()
            valid_emb_val, valid_labels_val, _ = trainer.sess.run(
                [
                    trainer.embeddings, trainer.valid_labels,
                    trainer.valid_ops["valid_loss_op"]
                ],
                feed_dict={
                    trainer.valid_features: features,
                    trainer.valid_labels: labels,
                    trainer.global_step: curr_step
                })
            # Save the embeddings and labels
            if embeddings_val is None:
                embeddings_val = valid_emb_val
                labels_val = valid_labels_val
            else:
示例#3
0
    def valid(self,
              data,
              spklist,
              batch_type="softmax",
              output_embeddings=False,
              aux_data=None):
        """Evaluate on the validation set

        Args:
            data: The data directory handed to the data loader.
            spklist: The spklist is a file that maps speaker names to indices.
            batch_type: `softmax` or `end2end`. The batch is `softmax-like` or `end2end-like`.
                        If the batch is `softmax-like`, the samples are drawn from different speakers;
                        if the batch is `end2end-like`, the samples are from N speakers with M segments per speaker.
            output_embeddings: Set True to output the corresponding embeddings and labels of the valid set.
                               If output_embeddings, an additional valid metric (e.g. EER) should be computed outside
                               the function.
            aux_data: The auxiliary data directory. Not used by this method.

        :return: valid_loss, embeddings and labels (None if output_embeddings is False
                 or if the loader yields no batch).
        """
        # Initialization will reset all the variables in the graph.
        # The local variables also need to be initialized for the metrics functions.
        self.sess.run(tf.global_variables_initializer())
        self.sess.run(tf.local_variables_initializer())
        assert batch_type == "softmax" or batch_type == "end2end", "The batch_type can only be softmax or end2end"

        curr_step = 0
        # Load the model if a checkpoint exists; otherwise validate the randomly initialized network.
        if os.path.isfile(os.path.join(self.model, "checkpoint")):
            curr_step = self.load()
        else:
            tf.logging.info(
                "[Warning] Cannot find model in %s. Random initialization is used in validation."
                % self.model)

        embeddings_val = None
        labels_val = None
        num_batches = 0

        if output_embeddings:
            # If we want to output embeddings, the features should be loaded in order
            data_loader = KaldiDataSeqQueue(
                data,
                spklist,
                num_parallel=2,
                max_qsize=10,
                batch_size=self.params.num_speakers_per_batch *
                self.params.num_segments_per_speaker,
                min_len=self.params.min_segment_len,
                max_len=self.params.max_segment_len,
                shuffle=False)
            data_loader.start()

            tf.logging.info("Generate valid embeddings.")
            # In this mode, the embeddings and labels will be saved and output. It needs more memory and takes longer
            # to process these values. The per-batch arrays are gathered in lists and concatenated once,
            # which avoids copying the growing buffer on every batch.
            emb_chunks = []
            label_chunks = []
            try:
                while True:
                    if num_batches % 100 == 0:
                        tf.logging.info("valid step: %d" % num_batches)
                    try:
                        features, labels = data_loader.fetch()
                    except DataOutOfRange:
                        break
                    valid_emb_val, valid_labels_val = self.sess.run(
                        [self.embeddings, self.valid_labels],
                        feed_dict={
                            self.valid_features: features,
                            self.valid_labels: labels,
                            self.global_step: curr_step
                        })
                    emb_chunks.append(valid_emb_val)
                    label_chunks.append(valid_labels_val)
                    num_batches += 1
            finally:
                # Always stop the loader, even if sess.run fails.
                data_loader.stop()

            if emb_chunks:
                embeddings_val = np.concatenate(emb_chunks, axis=0)
                labels_val = np.concatenate(label_chunks, axis=0)

        if batch_type == "softmax":
            data_loader = KaldiDataSeqQueue(
                data,
                spklist,
                num_parallel=2,
                max_qsize=10,
                batch_size=self.params.num_speakers_per_batch *
                self.params.num_segments_per_speaker,
                min_len=self.params.min_segment_len,
                max_len=self.params.max_segment_len,
                shuffle=True)
        elif batch_type == "end2end":
            # The num_valid_speakers_per_batch and num_valid_segments_per_speaker are only required when
            # End2End loss is used. Since we switch the loss function to softmax generalized e2e loss
            # when the e2e loss is used.
            assert "num_valid_speakers_per_batch" in self.params.dict and "num_valid_segments_per_speaker" in self.params.dict, \
                "Valid parameters should be set if E2E loss is selected"
            data_loader = KaldiDataRandomQueue(
                data,
                spklist,
                num_parallel=2,
                max_qsize=10,
                num_speakers=self.params.num_valid_speakers_per_batch,
                num_segments=self.params.num_valid_segments_per_speaker,
                min_len=self.params.min_segment_len,
                max_len=self.params.max_segment_len,
                shuffle=True)
        else:
            # Only reachable when assertions are disabled (python -O).
            raise ValueError("The batch_type can only be softmax or end2end")

        data_loader.start()
        # The counter is restarted so the final log reports the loss batches only.
        num_batches = 0
        try:
            for _ in range(self.params.valid_max_iterations):
                if num_batches % 100 == 0:
                    tf.logging.info("valid step: %d" % num_batches)
                try:
                    features, labels = data_loader.fetch()
                except DataOutOfRange:
                    break
                # Running the update op accumulates the loss; its return value is not needed.
                self.sess.run(self.valid_ops["valid_loss_op"],
                              feed_dict={
                                  self.valid_features: features,
                                  self.valid_labels: labels,
                                  self.global_step: curr_step
                              })
                num_batches += 1
        finally:
            data_loader.stop()

        loss, summary = self.sess.run(
            [self.valid_ops["valid_loss"], self.valid_summary])
        # We only save the summary for the last batch.
        self.valid_summary_writer.add_summary(summary, curr_step)
        # The valid loss is averaged over all the batches.
        tf.logging.info("[Validation %d batches] valid loss: %f" %
                        (num_batches, loss))

        # The output embeddings and labels can be used to compute EER or other metrics
        return loss, embeddings_val, labels_val