Example #1
import tensorflow as tf

# Config, CharFeaturizer, TFSpeechFeaturizer, ContextNet and DEFAULT_YAML come from
# the TensorFlowASR package; the exact module paths depend on the installed version.
def test_contextnet():
    config = Config(DEFAULT_YAML, learning=False)

    text_featurizer = CharFeaturizer(config.decoder_config)

    speech_featurizer = TFSpeechFeaturizer(config.speech_config)

    model = ContextNet(vocabulary_size=text_featurizer.num_classes, **config.model_config)

    model._build(speech_featurizer.shape)
    model.summary(line_length=150)

    model.add_featurizers(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer
    )

    concrete_func = model.make_tflite_function(timestamp=False).get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.experimental_new_converter = True
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    converter.convert()

    print("Converted successfully with no timestamp")

    concrete_func = model.make_tflite_function(timestamp=True).get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.experimental_new_converter = True
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    converter.convert()

    print("Converted successfully with timestamp")
Example #2
parser.add_argument("output",
                    type=str,
                    default=None,
                    help="TFLite file path to be exported")

args = parser.parse_args()

assert args.saved and args.output

config = Config(args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
text_featurizer = CharFeaturizer(config.decoder_config)

# build model
contextnet = ContextNet(**config.model_config,
                        vocabulary_size=text_featurizer.num_classes)
contextnet._build(speech_featurizer.shape)
contextnet.load_weights(args.saved)
contextnet.summary(line_length=150)
contextnet.add_featurizers(speech_featurizer, text_featurizer)

concrete_func = contextnet.make_tflite_function(
    greedy=True).get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
]
tflite_model = converter.convert()

if not os.path.exists(os.path.dirname(args.output)):
    os.makedirs(os.path.dirname(args.output))

# write the converted flatbuffer to the requested path
with open(args.output, "wb") as tflite_out:
    tflite_out.write(tflite_model)
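
Not part of the original script, but a quick sanity check on the exported file is to open
it with the TFLite interpreter and inspect its input/output signatures:

interpreter = tf.lite.Interpreter(model_path=args.output)
print(interpreter.get_input_details())
print(interpreter.get_output_details())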
Example #3
    train_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
        **vars(config.learning_config.train_dataset_config)
    )
    eval_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
        **vars(config.learning_config.eval_dataset_config)
    )

contextnet_trainer = TransducerTrainerGA(
    config=config.learning_config.running_config,
    text_featurizer=text_featurizer, strategy=strategy
)

with contextnet_trainer.strategy.scope():
    # build model
    contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes)
    contextnet._build(speech_featurizer.shape)
    contextnet.summary(line_length=120)

    optimizer = tf.keras.optimizers.Adam(
        TransformerSchedule(
            d_model=contextnet.dmodel,
            warmup_steps=config.learning_config.optimizer_config["warmup_steps"],
            max_lr=(0.05 / math.sqrt(contextnet.dmodel))
        ),
        beta_1=config.learning_config.optimizer_config["beta1"],
        beta_2=config.learning_config.optimizer_config["beta2"],
        epsilon=config.learning_config.optimizer_config["epsilon"]
    )
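    # Note (assumption, not from the original script): TransformerSchedule is taken to be
    # the standard Transformer ("Noam") warm-up schedule, clipped at max_lr. For step >= 1:
    #     lr(step) = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
    # A tiny sketch of that formula:
    def assumed_transformer_lr(step, d_model, warmup_steps, max_lr=None):
        lr = (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)
        return min(lr, max_lr) if max_lr is not None else lr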

contextnet_trainer.compile(model=contextnet, optimizer=optimizer)

if args.subwords and os.path.exists(args.subwords):
    text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config,
                                                       args.subwords)
else:
    raise ValueError("subwords must be set")

tf.random.set_seed(0)
assert args.saved

if args.tfrecords:
    test_dataset = ASRTFRecordDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.test_dataset_config))
else:
    test_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.test_dataset_config))

# build model
contextnet = ContextNet(**config.model_config,
                        vocabulary_size=text_featurizer.num_classes)
contextnet._build(speech_featurizer.shape)
contextnet.load_weights(args.saved)
contextnet.summary(line_length=120)
contextnet.add_featurizers(speech_featurizer, text_featurizer)

contextnet_tester = BaseTester(config=config.learning_config.running_config,
                               output_name=args.output_name)
contextnet_tester.compile(contextnet)
contextnet_tester.run(test_dataset)
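
The ASRTFRecordDataset / ASRSliceDataset constructors above are fed keyword arguments via
**vars(...), which simply expands a plain config object's attributes into kwargs. A
self-contained illustration of that idiom (the class and attribute names here are made up,
not TensorFlowASR's):

class DatasetConfig:
    def __init__(self):
        self.data_paths = ["/data/test.tsv"]
        self.shuffle = False

def make_dataset(**kwargs):
    return kwargs

# vars(obj) returns obj.__dict__, so every attribute becomes a keyword argument.
print(make_dataset(**vars(DatasetConfig())))
# -> {'data_paths': ['/data/test.tsv'], 'shuffle': False}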
def test_contextnet():
    config = Config(DEFAULT_YAML, learning=False)

    text_featurizer = CharFeaturizer(config.decoder_config)

    speech_featurizer = TFSpeechFeaturizer(config.speech_config)

    model = ContextNet(vocabulary_size=text_featurizer.num_classes,
                       **config.model_config)

    model._build(speech_featurizer.shape)
    model.summary(line_length=150)

    model.add_featurizers(speech_featurizer=speech_featurizer,
                          text_featurizer=text_featurizer)

    concrete_func = model.make_tflite_function(
        timestamp=False).get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [concrete_func])
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.experimental_new_converter = True
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite = converter.convert()

    print("Converted successfully with no timestamp")

    concrete_func = model.make_tflite_function(
        timestamp=True).get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [concrete_func])
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.experimental_new_converter = True
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    converter.convert()

    print("Converted successfully with timestamp")

    tflitemodel = tf.lite.Interpreter(model_content=tflite)
    signal = tf.random.normal([4000])

    input_details = tflitemodel.get_input_details()
    output_details = tflitemodel.get_output_details()
    tflitemodel.resize_tensor_input(input_details[0]["index"], [4000])
    tflitemodel.allocate_tensors()
    tflitemodel.set_tensor(input_details[0]["index"], signal)
    tflitemodel.set_tensor(input_details[1]["index"],
                           tf.constant(text_featurizer.blank, dtype=tf.int32))
    tflitemodel.set_tensor(
        input_details[2]["index"],
        tf.zeros(
            [
                config.model_config["prediction_num_rnns"], 2, 1,
                config.model_config["prediction_rnn_units"]
            ],
            dtype=tf.float32
        )
    )
    tflitemodel.invoke()
    hyp = tflitemodel.get_tensor(output_details[0]["index"])

    print(hyp)
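
    # Assumption (not in the original test): the exported function returns the transcript
    # as unicode code points, so the hypothesis can be decoded back to text like this.
    transcript = "".join(chr(int(u)) for u in hyp)
    print(transcript)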