# Register the remaining command-line options: the subwords file prefix and
# the prefix used when naming result files.
parser.add_argument(
    "--subwords_prefix",
    type=str, default=None,
    help="Prefix of file that stores generated subwords",
)
parser.add_argument(
    "--output_name",
    type=str, default="test",
    help="Result filename name prefix",
)

args = parser.parse_args()

# Forward the mixed-precision flag (args.mxp) to TensorFlow's graph optimizer.
optimizer_options = {"auto_mixed_precision": args.mxp}
tf.config.optimizer.set_experimental_options(optimizer_options)

# setup_devices is defined outside this excerpt; it is given the selected
# device id wrapped in a list, plus the cpu flag.
setup_devices([args.device], cpu=args.cpu)

from tiramisu_asr.configs.user_config import UserConfig
from tiramisu_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
from tiramisu_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tiramisu_asr.featurizers.text_featurizers import SubwordFeaturizer
from tiramisu_asr.runners.base_runners import BaseTester
from tiramisu_asr.models.conformer import Conformer

# Build the configuration object from DEFAULT_YAML and args.config, then the
# speech featurizer from its "speech_config" entry.
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])

# Only compute the "<prefix>.subwords" path when a prefix was supplied, so the
# existence check is skipped entirely when --subwords_prefix is unset.
subwords_file = (
    f"{args.subwords_prefix}.subwords" if args.subwords_prefix else None
)
if subwords_file and os.path.exists(subwords_file):
    print("Loading subwords ...")
    text_featurizer = SubwordFeaturizer.load_from_file(
        config["decoder_config"],
        args.subwords_prefix,
    )
Esempio n. 2
0
                    action="store_true",
                    help="Enable mixed precision")

# Register the device id and batch size options.
parser.add_argument(
    "--device",
    type=int, default=0,
    help="Device's id to run test on",
)
parser.add_argument(
    "--bs",
    type=int, default=None,
    help="Batch size",
)

args = parser.parse_args()

# Forward the mixed-precision flag (args.mxp) to TensorFlow's graph optimizer.
optimizer_options = {"auto_mixed_precision": args.mxp}
tf.config.optimizer.set_experimental_options(optimizer_options)

# setup_devices is defined outside this excerpt; it is given the selected
# device id wrapped in a list.
setup_devices([args.device])

from tiramisu_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tiramisu_asr.featurizers.text_featurizers import TextFeaturizer
from tiramisu_asr.configs.user_config import UserConfig
from tiramisu_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
from model import SelfAttentionDS2
from tiramisu_asr.runners.base_runners import BaseTester
from ctc_decoders import Scorer

# Seed TensorFlow's global random generator with a fixed value.
tf.random.set_seed(0)

# args.saved must be provided. An explicit raise is used instead of `assert`
# because assert statements are removed when Python runs with -O, which would
# silently disable this check.
if not args.saved:
    raise ValueError("args.saved is required but was not provided")

# Build the configuration object from DEFAULT_YAML and args.config, then the
# speech and text featurizers from their respective config entries.
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
text_featurizer = TextFeaturizer(config["decoder_config"])