Beispiel #1
0
    def add_output(self, name, labels, **kwargs):
        """Add metadata for output tensor in order.

        Each call describes the next output tensor: the tensor type is taken
        from the model's output tensor types at index ``len(self._outputs)``.

        Args:
          name: Name of the output tensor; also used as the prefix of the
            generated label file names.
          labels: Either a list of labels (stored under the ``None`` locale)
            or a ``collections.OrderedDict`` mapping locale -> list of labels.
          **kwargs: Extra keyword arguments forwarded to
            ``md_info.ClassificationTensorMd``.

        Raises:
          ValueError: If ``labels`` is neither a list nor an OrderedDict.
        """
        if isinstance(labels, list):
            # Wrap a bare label list as a single entry keyed by the None
            # locale, then re-dispatch with the OrderedDict form.
            return self.add_output(
                name, collections.OrderedDict([(None, labels)]), **kwargs)

        if not isinstance(labels, collections.OrderedDict):
            raise ValueError(
                '`labels` should be either a list of labels or an ordered dict mapping `locale` -> list of labels. got: {}'
                .format(labels))

        label_files = []
        for locale, locale_labels in labels.items():
            # One label file per locale; a falsy locale is named 'default'.
            file_name = '{}_labels_{}.txt'.format(name, locale or 'default')
            label_path = os.path.join(self._temp_folder.name, file_name)
            model_util.export_labels(label_path, locale_labels)
            label_files.append(
                md_info.LabelFileMd(file_path=label_path, locale=locale))

        # Outputs are matched to the model's output tensors by call order.
        output_index = len(self._outputs)
        output_types = writer_utils.get_output_tensor_types(self._model)
        self._outputs.append(
            md_info.ClassificationTensorMd(
                name=name,
                label_files=label_files,
                tensor_type=output_types[output_index],
                **kwargs))
Beispiel #2
0
    def _export_metadata(self, tflite_filepath, index_to_label,
                         export_metadata_json_file):
        """Export TFLite metadata.

        Reads the TFLite model at `tflite_filepath` and attaches general, input
        and classification-output metadata. It then overwrites the model file
        in place and optionally writes the metadata as JSON next to it.

        Args:
          tflite_filepath: Path of the TFLite model. It is read, then
            overwritten with the metadata-populated model.
          index_to_label: Labels of the custom classification head, written
            via `self._export_labels`.
          export_metadata_json_file: If truthy, also write the metadata JSON
            to `tflite_filepath` with its extension replaced by `.json`.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            # Read through tf.io.gfile, the same layer used for the writes
            # below, so every path this method touches is handled uniformly
            # (builtin open() cannot read the non-local paths GFile can write).
            with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
                model_buffer = f.read()

            general_md = md_info.GeneralMd(name=self._MODEL_NAME,
                                           description=self._MODEL_DESCRIPTION,
                                           version=self._MODEL_VERSION,
                                           author=self._MODEL_AUTHOR,
                                           licenses=self._MODEL_LICENSES)
            input_md = md_info.InputAudioTensorMd(self._INPUT_NAME,
                                                  self._INPUT_DESCRIPTION,
                                                  self._SAMPLE_RATE,
                                                  self._CHANNELS)

            # Looked up once. The custom head uses the last output tensor and,
            # when kept, the YAMNet head uses the first.
            output_tensor_types = writer_utils.get_output_tensor_types(
                model_buffer)

            def _classification_md(name, description, label_filename, labels,
                                   tensor_type):
                """Write `labels` into temp_dir and describe one output head."""
                label_filepath = os.path.join(temp_dir, label_filename)
                self._export_labels(label_filepath, labels)
                return md_info.ClassificationTensorMd(
                    name=name,
                    description=description,
                    label_files=[
                        md_info.LabelFileMd(file_path=label_filepath)
                    ],
                    tensor_type=tensor_type,
                    score_calibration_md=None)

            custom_output_md = _classification_md(
                self._CUSTOM_OUTPUT_NAME, self._CUSTOM_OUTPUT_DESCRIPTION,
                self._CUSTOM_LABEL_FILE, index_to_label,
                output_tensor_types[-1])

            if self._keep_yamnet_and_custom_heads:
                yamnet_output_md = _classification_md(
                    self._YAMNET_OUTPUT_NAME, self._YAMNET_OUTPUT_DESCRIPTION,
                    self._YAMNET_LABEL_FILE, self._yamnet_labels(),
                    output_tensor_types[0])
                # Same order as the tensor types chosen above: YAMNet first,
                # custom last.
                output_md = [yamnet_output_md, custom_output_md]
            else:
                output_md = [custom_output_md]

            # Populate metadata. This must happen while temp_dir still exists,
            # because the LabelFileMd entries point at files inside it.
            writer = md_writer.MetadataWriter.create_from_metadata_info_for_multihead(
                model_buffer=model_buffer,
                general_md=general_md,
                input_md=input_md,
                output_md_list=output_md)

            output_model = writer.populate()

            with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
                f.write(output_model)

            if export_metadata_json_file:
                metadata_json = writer.get_metadata_json()
                export_json_file = os.path.splitext(
                    tflite_filepath)[0] + '.json'
                # Use GFile, matching the model write above, so the JSON
                # sidecar can be written wherever the model can.
                with tf.io.gfile.GFile(export_json_file, 'w') as f:
                    f.write(metadata_json)