Args: train_data: A instance of audio_dataloader.DataLoader class. model_spec: Specification for the model. validation_data: Validation DataLoader. If None, skips validation process. batch_size: Number of samples per training step. If `use_hub_library` is False, it represents the base learning rate when train batch size is 256 and it's linear to the batch size. epochs: Number of epochs for training. model_dir: The location of the model checkpoint files. do_train: Whether to run training. train_whole_model: Boolean. By default, only the classification head is trained. When True, the base model is also trained. Returns: An instance based on AudioClassifier. """ if not isinstance(model_spec, audio_spec.BaseSpec): model_spec = model_spec.get(model_spec, model_dir=model_dir) task = cls(model_spec, train_data.index_to_label, shuffle=True, train_whole_model=train_whole_model) if do_train: task.train(train_data, validation_data, epochs, batch_size) return task # Shortcut function. create = AudioClassifier.create mm_export('audio_classifier.create').export_constant(__name__, 'create')
train_data: Training data. model_spec: Specification for the model. batch_size: Batch size for training. epochs: Number of epochs for training. shuffle: Whether the data should be shuffled. do_train: Whether to run training. Returns: An instance based on QuestionAnswer. """ model_spec = ms.get(model_spec) if compat.get_tf_behavior() not in model_spec.compat_tf_versions: raise ValueError( 'Incompatible versions. Expect {}, but got {}.'.format( model_spec.compat_tf_versions, compat.get_tf_behavior())) model = cls(model_spec, shuffle=shuffle) if do_train: tf.compat.v1.logging.info('Retraining the models...') model.train(train_data, epochs, batch_size) else: model.create_model() return model # Shortcut function. create = QuestionAnswer.create mm_export('question_answer.create').export_constant(__name__, 'create')
compat_tf_versions) self.name = name if input_image_shape is None: input_image_shape = [224, 224] self.input_image_shape = input_image_shape mobilenet_v2_spec = functools.partial( ImageModelSpec, uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4', compat_tf_versions=2, name='mobilenet_v2') mobilenet_v2_spec.__doc__ = util.wrap_doc(ImageModelSpec, 'Creates MobileNet v2 model spec.') mm_export('image_classifier.MobileNetV2Spec').export_constant( __name__, 'mobilenet_v2_spec') resnet_50_spec = functools.partial( ImageModelSpec, uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4', compat_tf_versions=2, name='resnet_50') resnet_50_spec.__doc__ = util.wrap_doc(ImageModelSpec, 'Creates ResNet 50 model spec.') mm_export('image_classifier.Resnet50Spec').export_constant( __name__, 'resnet_50_spec') efficientnet_lite0_spec = functools.partial( ImageModelSpec, uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2', compat_tf_versions=[1, 2],
spec = model_spec model_spec = spec(**model_spec_options) # Use model_dir or a temp folder to store intermediate checkpoints, etc. if model_dir is None: model_dir = tempfile.mkdtemp() recommendation = cls(model_spec, model_dir=model_dir, shuffle=shuffle, max_history_length=max_history_length, learning_rate=learning_rate, gradient_clip_norm=gradient_clip_norm) if do_train: tf.compat.v1.logging.info('Training recommendation model...') recommendation.train(train_data, validation_data, batch_size=batch_size, steps_per_epoch=steps_per_epoch, epochs=epochs) else: recommendation.create_model(do_train=False) return recommendation # Shortcut function. create = Recommendation.create mm_export('recommendation.create').export_constant(__name__, 'create')
# MobileBert spec factory for text classification. `functools.partial` pre-binds
# the constructor keywords below; callers may still override them at call time.
mobilebert_classifier_spec = functools.partial(
    BertClassifierModelSpec,
    uri='https://tfhub.dev/google/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT/1',
    is_tf2=False,
    distribution_strategy='off',
    name='MobileBert',
    default_batch_size=48)
# A partial object has no useful docstring of its own, so attach one built from
# the wrapped class.
mobilebert_classifier_spec.__doc__ = util.wrap_doc(
    BertClassifierModelSpec,
    'Creates MobileBert model spec for the text classification task. See also: `tflite_model_maker.text_classifier.BertClassifierSpec`.')
mm_export('text_classifier.MobileBertClassifierSpec').export_constant(
    __name__, 'mobilebert_classifier_spec')

# MobileBert spec factory for question answering: same hub module as the
# classifier factory above, with an explicit learning rate and a different
# default batch size.
mobilebert_qa_spec = functools.partial(
    BertQAModelSpec,
    uri='https://tfhub.dev/google/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT/1',
    is_tf2=False,
    distribution_strategy='off',
    learning_rate=4e-5,
    name='MobileBert',
    default_batch_size=32)
mobilebert_qa_spec.__doc__ = util.wrap_doc(
    BertQAModelSpec,
    'Creates MobileBert model spec for the question answer task. See also: `tflite_model_maker.question_answer.BertQaSpec`.')
'item_vocab_size': item_vocab_size, 'num_predictions': num_predictions, 'conv_num_filter_ratios': conv_num_filter_ratios, 'conv_kernel_size': conv_kernel_size, 'lstm_num_units': lstm_num_units, 'eval_top_k': eval_top_k, } def create_model(self): """Creates recommendation model based on params. Returns: Keras model. """ return _rm.RecommendationModel(self.params) recommendation_bow_spec = functools.partial(RecommendationSpec, encoder_type='bow') recommendation_cnn_spec = functools.partial(RecommendationSpec, encoder_type='cnn') recommendation_rnn_spec = functools.partial(RecommendationSpec, encoder_type='rnn') mm_export('recommendation.BowSpec').export_constant(__name__, 'recommendation_bow_spec') mm_export('recommendation.CnnSpec').export_constant(__name__, 'recommendation_cnn_spec') mm_export('recommendation.RnnSpec').export_constant(__name__, 'recommendation_rnn_spec')
AUDIO_CLASSIFICATION_MODELS = [ 'audio_browser_fft', 'audio_teachable_machine', 'audio_yamnet' ] RECOMMENDATION_MODELS = [ 'recommendation_bow', 'recommendation_rnn', 'recommendation_cnn', ] OBJECT_DETECTION_MODELS = [ 'efficientdet_lite0', 'efficientdet_lite1', 'efficientdet_lite2', 'efficientdet_lite3', 'efficientdet_lite4', ] mm_export('model_spec.IMAGE_CLASSIFICATION_MODELS').export_constant( __name__, 'IMAGE_CLASSIFICATION_MODELS') mm_export('model_spec.TEXT_CLASSIFICATION_MODELS').export_constant( __name__, 'TEXT_CLASSIFICATION_MODELS') mm_export('model_spec.QUESTION_ANSWER_MODELS').export_constant( __name__, 'QUESTION_ANSWER_MODELS') mm_export('model_spec.AUDIO_CLASSIFICATION_MODELS').export_constant( __name__, 'AUDIO_CLASSIFICATION_MODELS') mm_export('model_spec.RECOMMENDATION_MODELS').export_constant( __name__, 'RECOMMENDATION_MODELS') mm_export('model_spec.OBJECT_DETECTION_MODELS').export_constant( __name__, 'OBJECT_DETECTION_MODELS') @mm_export('model_spec.get') def get(spec_or_str): """Gets model spec by name or instance, and initializes by default."""
hparams = train_image_classifier_lib.HParams.get_hparams( batch_size=batch_size, train_epochs=epochs, do_fine_tuning=train_whole_model, dropout_rate=dropout_rate, learning_rate=learning_rate, warmup_steps=warmup_steps, model_dir=model_dir) image_classifier = cls(model_spec, train_data.index_to_label, shuffle=shuffle, hparams=hparams, use_augmentation=use_augmentation, representative_data=train_data) if do_train: tf.compat.v1.logging.info('Retraining the models...') image_classifier.train(train_data, validation_data, steps_per_epoch) else: # Used in evaluation. image_classifier.create_model(with_loss_and_metrics=True) return image_classifier # Shortcut function. create = ImageClassifier.create mm_export('image_classifier.create').export_constant(__name__, 'create')
# Model names grouped by task (presumably keys of MODEL_SPECS -- the mapping is
# defined outside this fragment, so confirm against it).
AUDIO_CLASSIFICATION_MODELS = [
    'audio_browser_fft',
    'audio_teachable_machine',
    'audio_yamnet',
]
RECOMMENDATION_MODELS = [
    'recommendation_bow',
    'recommendation_rnn',
    'recommendation_cnn',
]
OBJECT_DETECTION_MODELS = [
    'efficientdet_lite0',
    'efficientdet_lite1',
    'efficientdet_lite2',
    'efficientdet_lite3',
    'efficientdet_lite4',
]

mm_export('model_spec.QUESTION_ANSWER_MODELS').export_constant(
    __name__, 'QUESTION_ANSWER_MODELS')


@mm_export('model_spec.get')
def get(spec_or_str, *args, **kwargs):
  """Gets a model spec by name or object, instantiating it when needed.

  Args:
    spec_or_str: Either a string looked up in `MODEL_SPECS`, or a spec object:
      a class, a plain function, a `functools.partial`, or an already-built
      spec instance.
    *args: Positional arguments forwarded to the spec factory when one is
      called. Ignored when `spec_or_str` resolves to an existing instance.
    **kwargs: Keyword arguments forwarded to the spec factory when one is
      called. Ignored when `spec_or_str` resolves to an existing instance.

  Returns:
    The result of calling the resolved class / function / partial, or the
    resolved object itself when it is none of those.

  Raises:
    KeyError: Presumably, when `spec_or_str` is a string that is not a key of
      `MODEL_SPECS` (the lookup is a plain subscript).
  """
  # Local import so this block does not depend on the file's import header.
  import functools

  if isinstance(spec_or_str, str):
    model_spec = MODEL_SPECS[spec_or_str]
  else:
    model_spec = spec_or_str

  # `inspect.isfunction` is False for `functools.partial` objects, so partial
  # based spec factories must be recognised explicitly; otherwise they would be
  # handed back un-instantiated.
  is_factory = (
      inspect.isclass(model_spec) or inspect.isfunction(model_spec) or
      isinstance(model_spec, functools.partial))
  if is_factory:
    return model_spec(*args, **kwargs)
  return model_spec
# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an 'AS IS' BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Recommendation dataloader class.""" from tensorflow_examples.lite.model_maker.core.api import mm_export from tensorflow_examples.lite.model_maker.third_party.recommendation.ml.configs import input_config_pb2 from tensorflow_examples.lite.model_maker.third_party.recommendation.ml.configs import model_config # Shortcut for classes. ModelHParams = model_config.ModelConfig mm_export('recommendation.spec.ModelHParams').export_constant( __name__, 'ModelHParams') InputSpec = input_config_pb2.InputConfig Feature = input_config_pb2.Feature FeatureGroup = input_config_pb2.FeatureGroup FeatureType = input_config_pb2.FeatureType EncoderType = input_config_pb2.EncoderType mm_export('recommendation.spec.InputSpec').export_constant( __name__, 'InputSpec') mm_export('recommendation.spec.Feature').export_constant(__name__, 'Feature') mm_export('recommendation.spec.FeatureGroup').export_constant( __name__, 'FeatureGroup') EncoderType.__doc__ = 'EncoderType Enum (valid: BOW, CNN, LSTM).' mm_export('recommendation.spec.EncoderType').export_constant(