示例#1
0
文件: evaluate.py 项目: wxy1988/ASR
class Evaluator(object):
    """Run a trained model over test data: decoding, loss and perplexity.

    An instance is not usable until ``init_from_config`` or
    ``init_from_existed`` has set ``self.model``, ``self.sess`` and
    ``self.data_reader``.
    """

    def __init__(self):
        pass

    def init_from_config(self, config):
        """Build the test model, restore the latest checkpoint, create a reader.

        Args:
            config: configuration object. This method reads ``config.model``,
                ``config.test.num_gpus`` and ``config.model_dir`` and passes
                the whole object on to the model class and to ``DataReader``.
        """
        # SECURITY NOTE: ``eval`` executes whatever expression is stored in
        # ``config.model``; only use configuration from trusted sources.
        self.model = eval(config.model)(config, config.test.num_gpus)
        self.model.build_test_model()

        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        sess_config.allow_soft_placement = True
        self.sess = tf.Session(config=sess_config, graph=self.model.graph)
        # ``is_debug`` is a module-level flag defined outside this class.
        if is_debug:
            self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
        # Restore model.
        self.model.saver.restore(self.sess,
                                 tf.train.latest_checkpoint(config.model_dir))

        self.data_reader = DataReader(config)

    def init_from_existed(self, model, sess, data_reader):
        """Adopt an already-built model, session and data reader.

        Args:
            model: model object; its ``graph`` must equal ``sess.graph``.
            sess: session used for every later ``run`` call.
            data_reader: object providing the batch iterators and
                ``indices_to_words``.
        """
        assert model.graph == sess.graph
        self.sess = sess
        self.model = model
        self.data_reader = data_reader

    def beam_search(self, X):
        """Run ``self.model.prediction`` for the source batch ``X``."""
        feed = expand_feed_dict({self.model.src_pls: X})
        return self.sess.run(self.model.prediction, feed_dict=feed)

    def loss(self, X, Y):
        """Run ``self.model.loss_sum`` for source batch ``X`` and target ``Y``."""
        feed = expand_feed_dict({
            self.model.src_pls: X,
            self.model.dst_pls: Y,
        })
        return self.sess.run(self.model.loss_sum, feed_dict=feed)

    @staticmethod
    def _strip_bpe(in_path, out_path):
        """Copy ``in_path`` to ``out_path`` with BPE markers removed.

        Applies, line by line, the same substitution as
        ``sed -r 's/(@@ )|(@@ ?$)//g'`` but in-process, so no shell is
        involved and paths containing spaces or shell metacharacters are safe.
        """
        # Local imports keep this class independent of the file's import list.
        import io
        import re

        bpe_marker = re.compile(r'(@@ )|(@@ ?$)')
        # newline='\n' makes '\n' the only line separator, as it is for sed.
        with io.open(in_path, 'r', encoding='utf8', newline='\n') as fin, \
                io.open(out_path, 'w', encoding='utf8', newline='\n') as fout:
            for line in fin:
                has_newline = line.endswith('\n')
                text = line[:-1] if has_newline else line
                fout.write(bpe_marker.sub('', text))
                if has_newline:
                    fout.write(u'\n')

    def translate(self, src_path, output_path, batch_size):
        """Decode every batch of ``src_path`` and write the result file.

        Each output line is ``<uttid>\\t<sentence>``. Lines are first written
        to ``output_path + '.tmp'``; that file is then copied to
        ``output_path`` with BPE markers removed and is always deleted, even
        when decoding fails.

        Args:
            src_path: passed to ``data_reader.get_test_batches``.
            output_path: destination of the final text file.
            batch_size: passed to ``data_reader.get_test_batches``.
        """
        logging.info('Translate %s.' % src_path)
        tmp = output_path + '.tmp'
        count = 0
        token_count = 0
        start = time.time()
        try:
            with codecs.open(tmp, 'w', 'utf8') as fd:
                for X, uttids in self.data_reader.get_test_batches(
                        src_path, batch_size):
                    Y = self.beam_search(X)
                    sents = self.data_reader.indices_to_words(Y)
                    assert len(X) == len(sents)
                    for sent, uttid in zip(sents, uttids):
                        print(uttid + '\t' + sent, file=fd)
                    count += len(X)
                    token_count += np.sum(np.not_equal(Y, 3))  # 3: </s>
                    time_span = time.time() - start
                    # max(..., 1) keeps the speed figure finite when no
                    # token has been counted yet.
                    logging.info(
                        '{0} sentences ({1} tokens) processed in {2:.2f} minutes (speed: {3:.4f} sec/token).'
                        .format(count, token_count, time_span / 60,
                                time_span / max(token_count, 1)))
            # Remove BPE flag, if have.
            self._strip_bpe(tmp, output_path)
        finally:
            if os.path.exists(tmp):
                os.remove(tmp)
        logging.info('The result file was saved in %s.' % output_path)

    def ppl(self, src_path, dst_path, batch_size):
        """Compute ``exp(total loss / token count)`` over a parallel test set.

        Tokens are counted as the entries of each target batch that are
        greater than 0 (presumably 0 is the padding id -- confirm against the
        data reader).

        Returns:
            The perplexity value, which is also logged.

        Raises:
            ZeroDivisionError: if the reader yields no batches.
        """
        logging.info('Calculate PPL for %s and %s.' % (src_path, dst_path))
        total_loss = 0
        total_tokens = 0
        batches = self.data_reader.get_test_batches_with_target(
            src_path, dst_path, batch_size)
        for X, Y in batches:
            total_loss += self.loss(X, Y)
            total_tokens += np.sum(np.greater(Y, 0))
        ppl = np.exp(total_loss / total_tokens)
        logging.info('PPL: %.4f' % ppl)
        return ppl

    def evaluate(self, batch_size, **kargs):
        """Evaluate the model on dev set by decoding it to a file.

        BLEU and perplexity scoring are disabled in this version, so the
        optional ``cmd``, ``ref_path`` and ``dst_path`` keyword arguments are
        accepted but ignored.

        Args:
            batch_size: forwarded to ``translate``.
            **kargs: must contain ``src_path`` and ``output_path``.

        Returns:
            Always ``None``.

        Raises:
            KeyError: if ``src_path`` or ``output_path`` is missing.
        """
        src_path = kargs['src_path']
        output_path = kargs['output_path']
        self.translate(src_path, output_path, batch_size)
        return None
示例#2
0
class Evaluator(object):
    """Run a trained model over test data: decoding, labels, loss and scores.

    An instance is not usable until ``init_from_config`` or
    ``init_from_existed`` has set ``self.model``, ``self.sess`` and
    ``self.data_reader``.
    """

    def __init__(self):
        pass

    def init_from_config(self, config):
        """Build the test model, restore the latest checkpoint, create a reader.

        Args:
            config: configuration object. This method reads ``config.model``,
                ``config.test.num_gpus`` and ``config.model_dir`` and passes
                the whole object on to the model class and to ``DataReader``.
        """
        # SECURITY NOTE: ``eval`` executes whatever expression is stored in
        # ``config.model``; only use configuration from trusted sources.
        self.model = eval(config.model)(config, config.test.num_gpus)
        self.model.build_test_model()

        # Print the number of total parameters
        print_num_of_total_parameters()

        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        sess_config.allow_soft_placement = True
        self.sess = tf.Session(config=sess_config, graph=self.model.graph)
        # Restore model.
        self.model.saver.restore(self.sess, tf.train.latest_checkpoint(config.model_dir))

        self.data_reader = DataReader(config)

    def init_from_existed(self, model, sess, data_reader):
        """Adopt an already-built model, session and data reader.

        ``model.graph`` must equal ``sess.graph``.
        """
        assert model.graph == sess.graph
        self.sess = sess
        self.model = model
        self.data_reader = data_reader

    def beam_search(self, X):
        """Run ``self.model.prediction`` for the source batch ``X``."""
        feed = expand_feed_dict({self.model.src_pls: X})
        return self.sess.run(self.model.prediction, feed_dict=feed)

    def beam_search_label(self, X, Y, Z, X_lens):
        """Run the prediction and label-prediction tensors in one call.

        Args:
            X: fed to ``model.src_pls``.
            Y: fed to ``model.dst_pls``.
            Z: fed to ``model.label_pls``.
            X_lens: fed to ``model.src_len_pls``.

        Returns:
            The two-element result of ``sess.run`` for
            ``[model.prediction, model.prediction_label]``.
        """
        feed = expand_feed_dict({
            self.model.src_pls: X,
            self.model.dst_pls: Y,
            self.model.label_pls: Z,
            self.model.src_len_pls: X_lens,
        })
        fetches = [self.model.prediction, self.model.prediction_label]
        return self.sess.run(fetches, feed_dict=feed)

    def loss(self, X, Y):
        """Run ``self.model.loss_sum`` for source batch ``X`` and target ``Y``."""
        feed = expand_feed_dict({self.model.src_pls: X, self.model.dst_pls: Y})
        return self.sess.run(self.model.loss_sum, feed_dict=feed)

    def loss_label(self, X, Y, Z):
        """Run ``self.model.loss_sum`` with the label placeholder ``Z`` fed too."""
        feed = expand_feed_dict({
            self.model.src_pls: X,
            self.model.dst_pls: Y,
            self.model.label_pls: Z,
        })
        return self.sess.run(self.model.loss_sum, feed_dict=feed)

    @staticmethod
    def _new_temp_path():
        """Create an empty temporary file and return its path.

        ``mkstemp`` returns an already-open OS-level descriptor; it is closed
        here because callers reopen the file by path, and dropping the
        descriptor would leak it.
        """
        handle, path = mkstemp()
        os.close(handle)
        return path

    @staticmethod
    def _strip_bpe(in_path, out_path):
        """Copy ``in_path`` to ``out_path`` with BPE markers removed.

        Applies, line by line, the same substitution as
        ``sed -r 's/(@@ )|(@@ ?$)//g'`` but in-process, so no shell is
        involved and paths containing spaces or shell metacharacters are safe.
        """
        # Local imports keep this class independent of the file's import list.
        import io
        import re

        bpe_marker = re.compile(r'(@@ )|(@@ ?$)')
        # newline='\n' makes '\n' the only line separator, as it is for sed.
        with io.open(in_path, 'r', encoding='utf8', newline='\n') as fin, \
                io.open(out_path, 'w', encoding='utf8', newline='\n') as fout:
            for line in fin:
                has_newline = line.endswith('\n')
                text = line[:-1] if has_newline else line
                fout.write(bpe_marker.sub('', text))
                if has_newline:
                    fout.write(u'\n')

    def translate(self, src_path, dst_path, lbl_path, output_path, output_label_path, batch_size):
        """Decode the test set and write sentence and label predictions.

        Both outputs are staged in temporary files, then copied to
        ``output_path`` / ``output_label_path`` with BPE markers removed. The
        temporary files are deleted even when decoding fails.

        Args:
            src_path, dst_path, lbl_path, batch_size: passed to
                ``data_reader.get_test_batches_with_target_with_label``.
            output_path: destination of the predicted sentences.
            output_label_path: destination of the predicted labels.
        """
        logging.info('Translate %s.' % src_path)
        tmp = self._new_temp_path()
        tmp_label = self._new_temp_path()

        count = 0
        token_count = 0
        start = time.time()
        try:
            with codecs.open(tmp, 'w', 'utf8') as fd, \
                    codecs.open(tmp_label, 'w', 'utf8') as fd_label:
                batches = self.data_reader.get_test_batches_with_target_with_label(
                    src_path, dst_path, lbl_path, batch_size)
                for X, ref, label, src_lens in batches:
                    Y, Z = self.beam_search_label(X, ref, label, src_lens)
                    sents = self.data_reader.indices_to_words(Y, src_lens)
                    assert len(X) == len(sents)
                    for sent in sents:
                        print(sent, file=fd)
                    count += len(X)
                    token_count += np.sum(np.not_equal(Y, 3))  # 3: </s>
                    time_span = time.time() - start
                    # max(..., 1) keeps the speed figure finite when no
                    # token has been counted yet.
                    logging.info('{0} sentences ({1} tokens) processed in {2:.2f} minutes (speed: {3:.4f} sec/token).'.
                                 format(count, token_count, time_span / 60, time_span / max(token_count, 1)))

                    # Save the prediction of label
                    sents_label = self.data_reader.indices_to_words(Z, src_lens, o='lbl')
                    assert len(X) == len(sents_label)
                    for sent in sents_label:
                        print(sent, file=fd_label)

            # Remove BPE flag, if have.
            self._strip_bpe(tmp, output_path)
            logging.info('The result file was saved in %s.' % output_path)

            self._strip_bpe(tmp_label, output_label_path)
            logging.info('The label file was saved in %s.' % output_label_path)
        finally:
            for path in (tmp, tmp_label):
                if os.path.exists(path):
                    os.remove(path)

    def ppl(self, src_path, dst_path, batch_size):
        """Compute ``exp(total loss / token count)`` over a parallel test set.

        Tokens are counted as the entries of each target batch that are
        greater than 0 (presumably 0 is the padding id -- confirm against the
        data reader).

        Returns:
            The perplexity value, which is also logged.

        Raises:
            ZeroDivisionError: if the reader yields no batches.
        """
        logging.info('Calculate PPL for %s and %s.' % (src_path, dst_path))
        total_loss = 0
        total_tokens = 0
        for X, Y in self.data_reader.get_test_batches_with_target(src_path, dst_path, batch_size):
            total_loss += self.loss(X, Y)
            total_tokens += np.sum(np.greater(Y, 0))
        ppl = np.exp(total_loss / total_tokens)
        logging.info('PPL: %.4f' % ppl)
        return ppl

    def fscore(self, lbl_path, output_label_path):
        """Compute character-level precision/recall/F1 for the ``'E'`` label.

        The two files are read in lockstep. A pair of lines whose lengths
        differ is counted as an error and skipped; otherwise characters are
        compared position by position.

        Args:
            lbl_path: reference label file (UTF-8).
            output_label_path: predicted label file (UTF-8).

        Returns:
            Tuple ``(precision, recall, fscore)`` of floats.
        """
        logging.info('Calculate P/R/F for %s and %s.' % (lbl_path, output_label_path))

        # Counters start at 1, not 0; this also guarantees non-zero
        # denominators below (presumably intended as smoothing).
        tp, fp, fn = 1, 1, 1
        err = 0
        with codecs.open(lbl_path, 'r', 'utf8') as ref_file, \
                codecs.open(output_label_path, 'r', 'utf8') as pred_file:
            for ref, pred in zip(ref_file, pred_file):
                if len(ref) != len(pred):
                    err += 1
                    continue
                for x, y in zip(ref, pred):
                    if y == 'E':
                        if x == 'E':
                            tp += 1
                        else:
                            fp += 1
                    elif x == 'E':
                        fn += 1
        print('tp:{}, fp:{}, fn:{}, err:{}'.format(tp, fp, fn, err))

        # float() forces true division even on Python 2 without
        # ``from __future__ import division``.
        precision = float(tp) / (tp + fp)
        recall = float(tp) / (tp + fn)
        fscore = 2 * precision * recall / (precision + recall)

        logging.info('precision: %.4f' % precision)
        logging.info('recall: %.4f' % recall)
        logging.info('fscore: %.4f' % fscore)
        return precision, recall, fscore

    def evaluate(self, batch_size, **kargs):
        """Evaluate the model on dev set and return the label F-score.

        Args:
            batch_size: forwarded to ``translate`` and ``ppl``.
            **kargs: must contain ``src_path``, ``ref_path``, ``label_path``,
                ``output_path`` and ``output_label_path``. If ``dst_path`` is
                present, perplexity is also computed (and logged) against it.
                A ``cmd`` entry is accepted but unused.

        Returns:
            The F-score from ``fscore`` as a float.

        Raises:
            KeyError: if any required keyword argument is missing.
        """
        src_path = kargs['src_path']
        dst_path = kargs['ref_path']
        lbl_path = kargs['label_path']
        output_path = kargs['output_path']
        output_label_path = kargs['output_label_path']
        self.translate(src_path, dst_path, lbl_path, output_path, output_label_path, batch_size)

        if 'dst_path' in kargs:
            self.ppl(src_path, kargs['dst_path'], batch_size)

        # ``label_path`` is required above, so the label score is always
        # computed and returned.
        _, _, f_score = self.fscore(lbl_path, output_label_path)
        return float(f_score)
示例#3
0
class Evaluator(object):
    """Run a trained model over test data: decoding, loss, perplexity, BLEU.

    An instance is not usable until one of the ``init_from_*`` methods has
    set ``self.model``, ``self.sess`` and ``self.data_reader``.
    """

    def __init__(self):
        pass

    def _restore_latest(self, config):
        """Restore the latest checkpoint in ``config.model_dir`` into ``self.sess``."""
        tf.train.Saver().restore(
            self.sess, tf.train.latest_checkpoint(config.model_dir))

    def init_from_config(self, config):
        """Build the test model, restore the latest checkpoint, create a reader.

        If restoring raises ``tf.errors.NotFoundError``,
        ``roll_back_to_previous_version(config)`` is called and the restore is
        attempted once more.

        Args:
            config: configuration object. This method reads ``config.model``,
                ``config.test.num_gpus`` and ``config.model_dir`` and passes
                the whole object on to the model class and to ``DataReader``.
        """
        # SECURITY NOTE: ``eval`` executes whatever expression is stored in
        # ``config.model``; only use configuration from trusted sources.
        self.model = eval(config.model)(config, config.test.num_gpus)
        self.model.build_test_model()

        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        sess_config.allow_soft_placement = True
        self.sess = tf.Session(config=sess_config)

        # Restore model.
        try:
            self._restore_latest(config)
        except tf.errors.NotFoundError:
            roll_back_to_previous_version(config)
            self._restore_latest(config)

        self.data_reader = DataReader(config)

    def init_from_frozen_graphdef(self, config):
        """Initialise from a frozen graph file, creating that file if needed.

        The frozen graph is looked up in ``config.model_dir``. When it is
        missing, the evaluator is initialised through ``init_from_config`` and
        a frozen graph is written for the next run. When it exists, it is
        imported and ``self.model`` becomes an ``AttrDict`` holding only
        ``src_pls``, ``dst_pls`` and ``predictions`` (no ``loss_sum`` entry is
        set on that path).
        """
        frozen_graph_path = os.path.join(config.model_dir,
                                         'freeze_graph_test.py')
        if os.path.exists(frozen_graph_path):
            self._load_frozen_graph(config, frozen_graph_path)
            return

        # If the file doesn't exist, create it.
        logging.warning(
            'The frozen graph does not exist, using \'init_from_config\' '
            'instead and creating a frozen graph for next use.')
        self.init_from_config(config)
        self._export_frozen_graph(frozen_graph_path)

    def _export_frozen_graph(self, frozen_graph_path):
        """Freeze the variables of ``self.sess`` and write the graph to disk."""
        # Local imports keep this class independent of the file's import list.
        import shutil
        import tempfile

        # A fresh private directory avoids clashes with a stale directory
        # left behind by an earlier process.
        save_dir = tempfile.mkdtemp(prefix='graph-')
        try:
            save_path = os.path.join(save_dir, 'ckpt')
            tf.train.Saver().save(sess=self.sess, save_path=save_path)

            with tf.Session(graph=tf.Graph()) as sess:
                output_node_names = ['loss_sum', 'predictions']
                # We import the meta graph in the current default Graph
                saver = tf.train.import_meta_graph(save_path + '.meta',
                                                   clear_devices=True)

                # We restore the weights
                saver.restore(sess, save_path)

                # We use a built-in TF helper to export variables to constants
                output_graph_def = tf.graph_util.convert_variables_to_constants(
                    sess,  # The session is used to retrieve the weights
                    tf.get_default_graph().as_graph_def(),  # Source of the nodes
                    output_node_names  # Selects the useful nodes
                )

                # Finally we serialize and dump the output graph to the filesystem
                with tf.gfile.GFile(frozen_graph_path, "wb") as f:
                    f.write(output_graph_def.SerializeToString())
                    logging.info("%d ops in the final graph." %
                                 len(output_graph_def.node))
        finally:
            # Remove temp files, also when freezing fails.
            shutil.rmtree(save_dir, ignore_errors=True)

    def _load_frozen_graph(self, config, frozen_graph_path):
        """Import a frozen graph and expose its tensors via ``self.model``."""
        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        sess_config.allow_soft_placement = True
        self.sess = tf.Session(config=sess_config)
        self.data_reader = DataReader(config)

        # We load the protobuf file from the disk and parse it to retrieve the
        # unserialized graph_def
        with tf.gfile.GFile(frozen_graph_path, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        # Import the graph_def into the current default graph.
        tf.import_graph_def(graph_def)
        graph = tf.get_default_graph()
        self.model = AttrDict()

        def collect_placeholders(prefix):
            """Collect ``import/<prefix>_0:0``, ``_1:0``, ... until one is missing."""
            found = []
            while True:
                name = 'import/{}_{}:0'.format(prefix, len(found))
                try:
                    found.append(graph.get_tensor_by_name(name))
                except KeyError:
                    return tuple(found)

        self.model['src_pls'] = collect_placeholders('src_pl')
        self.model['dst_pls'] = collect_placeholders('dst_pl')
        self.model['predictions'] = graph.get_tensor_by_name(
            'import/predictions:0')

    def init_from_existed(self, model, sess, data_reader):
        """Adopt an already-built model, session and data reader."""
        self.sess = sess
        self.model = model
        self.data_reader = data_reader

    def beam_search(self, X):
        """Run ``self.model.predictions`` for the source batch ``X``."""
        feed = expand_feed_dict({self.model.src_pls: X})
        return self.sess.run(self.model.predictions, feed_dict=feed)

    def loss(self, X, Y):
        """Run ``self.model.loss_sum`` for source batch ``X`` and target ``Y``."""
        feed = expand_feed_dict({
            self.model.src_pls: X,
            self.model.dst_pls: Y,
        })
        return self.sess.run(self.model.loss_sum, feed_dict=feed)

    @staticmethod
    def _strip_bpe(in_path, out_path):
        """Copy ``in_path`` to ``out_path`` with BPE markers removed.

        Applies, line by line, the same substitution as
        ``sed -r 's/(@@ )|(@@ ?$)//g'`` but in-process, so no shell is
        involved and paths containing spaces or shell metacharacters are safe.
        """
        # Local imports keep this class independent of the file's import list.
        import io
        import re

        bpe_marker = re.compile(r'(@@ )|(@@ ?$)')
        # newline='\n' makes '\n' the only line separator, as it is for sed.
        with io.open(in_path, 'r', encoding='utf8', newline='\n') as fin, \
                io.open(out_path, 'w', encoding='utf8', newline='\n') as fout:
            for line in fin:
                has_newline = line.endswith('\n')
                text = line[:-1] if has_newline else line
                fout.write(bpe_marker.sub('', text))
                if has_newline:
                    fout.write(u'\n')

    def translate(self, src_path, output_path, batch_size):
        """Decode every batch of ``src_path`` and write one sentence per line.

        Predictions are truncated to ``len(X)`` rows before conversion.
        Sentences are staged in a temporary file, then copied to
        ``output_path`` with BPE markers removed; the temporary file is
        deleted even when decoding fails.

        Args:
            src_path: passed to ``data_reader.get_test_batches``.
            output_path: destination of the final text file.
            batch_size: passed to ``data_reader.get_test_batches``.
        """
        logging.info('Translate %s.' % src_path)
        # mkstemp returns an open descriptor; close it because the file is
        # reopened by path below and the descriptor would otherwise leak.
        handle, tmp = mkstemp()
        os.close(handle)
        count = 0
        token_count = 0
        epsilon = 1e-6  # avoids dividing by zero in the speed figure
        start = time.time()
        try:
            with codecs.open(tmp, 'w', 'utf8') as fd:
                for X in self.data_reader.get_test_batches(src_path, batch_size):
                    Y = self.beam_search(X)
                    Y = Y[:len(X)]
                    sents = self.data_reader.indices_to_words(Y)
                    assert len(X) == len(sents)
                    for sent in sents:
                        print(sent, file=fd)
                    count += len(X)
                    token_count += np.sum(np.not_equal(Y, 3))  # 3: </s>
                    time_span = time.time() - start
                    logging.info(
                        '{0} sentences ({1} tokens) processed in {2:.2f} minutes (speed: {3:.4f} sec/token).'
                        .format(count, token_count, time_span / 60,
                                time_span / (token_count + epsilon)))
            # Remove BPE flag, if have.
            self._strip_bpe(tmp, output_path)
        finally:
            if os.path.exists(tmp):
                os.remove(tmp)
        logging.info('The result file was saved in %s.' % output_path)

    def ppl(self, src_path, dst_path, batch_size):
        """Compute ``exp(total loss / token count)`` over a parallel test set.

        Tokens are counted as the entries of each target batch that are
        greater than 0 (presumably 0 is the padding id -- confirm against the
        data reader).

        Returns:
            The perplexity value, which is also logged.

        Raises:
            ZeroDivisionError: if the reader yields no batches.
        """
        logging.info('Calculate PPL for %s and %s.' % (src_path, dst_path))
        total_loss = 0
        total_tokens = 0
        batches = self.data_reader.get_test_batches_with_target(
            src_path, dst_path, batch_size)
        for X, Y in batches:
            total_loss += self.loss(X, Y)
            total_tokens += np.sum(np.greater(Y, 0))
        ppl = np.exp(total_loss / total_tokens)
        logging.info('PPL: %.4f' % ppl)
        return ppl

    def evaluate(self, batch_size, **kargs):
        """Evaluate the model on dev set.

        Decodes ``src_path`` to ``output_path``. If ``ref_path`` is given, the
        scoring command is run and its output parsed as the BLEU score. If
        ``dst_path`` is given, perplexity is computed and logged.

        Args:
            batch_size: forwarded to ``translate`` and ``ppl``.
            **kargs: requires ``src_path`` and ``output_path``; optional
                ``ref_path``, ``dst_path`` and ``cmd``. ``cmd`` is a shell
                command template with ``{ref}`` and ``{output}`` fields and
                defaults to a ``multi-bleu.perl`` pipeline.

        Returns:
            The BLEU score as a float, ``0`` when the command output cannot
            be parsed as a float, or ``None`` when no ``ref_path`` was given.

        Raises:
            KeyError: if ``src_path`` or ``output_path`` is missing.
        """
        src_path = kargs['src_path']
        output_path = kargs['output_path']
        default_cmd = "perl multi-bleu.perl {ref} < {output} 2>/dev/null | awk '{{print($3)}}' | awk -F, '{{print $1}}'"
        cmd = kargs.get('cmd', default_cmd).strip()
        logging.info('Evaluation command: ' + cmd)
        self.translate(src_path, output_path, batch_size)
        bleu = None
        if 'ref_path' in kargs:
            ref_path = kargs['ref_path']
            try:
                # NOTE(review): ``commands`` exists only on Python 2; on
                # Python 3 the equivalent is ``subprocess.getoutput``.
                bleu = commands.getoutput(
                    cmd.format(ref=ref_path, output=output_path))
                bleu = float(bleu)
            except ValueError as e:
                logging.warning(
                    'An error raised when calculate BLEU: {}'.format(e))
                bleu = 0
            logging.info('BLEU: {}'.format(bleu))
        if 'dst_path' in kargs:
            self.ppl(src_path, kargs['dst_path'], batch_size)
        return bleu