Example #1
import decoder

import numpy as np
import torch

# Assumed project-local modules: the snippet uses Labels, Transducer and
# AudioDataset below but omits these imports (cf. Example #5).
from data import AudioDataset, Labels
from model import Transducer
from warp_rnnt import rnnt_loss

torch.backends.cudnn.benchmark = True
torch.manual_seed(0)
np.random.seed(0)

labels = Labels()

model = Transducer(128,
                   len(labels),
                   512,
                   256,
                   am_layers=3,
                   lm_layers=3,
                   dropout=0.3,
                   am_checkpoint='exp/am.bin',
                   lm_checkpoint='exp/lm.bin')

train = AudioDataset(
    '/media/lytic/STORE/ru_open_stt_wav/public_youtube1120_hq.txt', labels)
test = AudioDataset(
    '/media/lytic/STORE/ru_open_stt_wav/public_youtube700_val.txt', labels)

train.filter_by_conv(model.encoder.conv)
train.filter_by_length(400)

test.filter_by_conv(model.encoder.conv)
test.filter_by_length(200)
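The snippet stops before the training loop, so the imported rnnt_loss is never called. As a point of reference, a minimal self-contained call to warp_rnnt's rnnt_loss looks like the following; the shapes and vocabulary size are made up for illustration, and the package requires CUDA tensors:

import torch
from warp_rnnt import rnnt_loss

# Hypothetical sizes: batch N, frames T, label length U, vocab V (blank = 0).
N, T, U, V = 2, 10, 4, 28
log_probs = torch.randn(N, T, U + 1, V).log_softmax(dim=-1).cuda().requires_grad_()
labels = torch.randint(1, V, (N, U), dtype=torch.int).cuda()
frames_lengths = torch.full((N,), T, dtype=torch.int).cuda()
labels_lengths = torch.full((N,), U, dtype=torch.int).cuda()

loss = rnnt_loss(log_probs, labels, frames_lengths, labels_lengths, reduction='mean')
loss.backward()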
Example #2
    subsample_factor=1,
    discourse_aware=False,
    skip_thought=False)

vocab = trainset.vocab
batch, is_new_epoch = trainset.next()
xs = [np2tensor(x).float() for x in batch['xs']]
xlens = torch.IntTensor([len(x) for x in batch['xs']])
xs = pad_list(xs, 0.0)
ys = batch['ys']
_ys = [np2tensor(np.fromiter(y, dtype=np.int64), -1)
       for y in ys]  # TODO(vishay): optimize for GPU
ys_out_pad = pad_list(_ys, 0).long()
ylens = np2tensor(np.fromiter([y.size(0) for y in _ys], dtype=np.int32))
# TODO use config file
model = Transducer(81, vocab, 256, 3, args.dropout, bidirectional=args.bi)
print(model)
for param in model.parameters():
    torch.nn.init.uniform_(param, -0.1, 0.1)  # in-place init; torch.nn.init.uniform is deprecated
if args.init:
    model.load_state_dict(torch.load(args.init))
if args.initam:
    model.encoder.load_state_dict(torch.load(args.initam))
if args.cuda:
    model.cuda()

optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                   model.parameters()),
                            lr=args.lr,
                            momentum=.9)

# data set
# trainset = SequentialLoader('train', args.batch_size)
# devset = SequentialLoader('dev', args.batch_size)
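pad_list here comes from the surrounding codebase (an ESPnet-style helper); PyTorch's built-in pad_sequence does the same padding and can stand in for it in a self-contained test:

import torch
from torch.nn.utils.rnn import pad_sequence

# Two variable-length feature sequences with 81-dim frames, as in the model above.
xs = [torch.randn(5, 81), torch.randn(3, 81)]
padded = pad_sequence(xs, batch_first=True, padding_value=0.0)
print(padded.shape)  # torch.Size([2, 5, 81])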
Example #3
os.makedirs(args.out, exist_ok=True)
with open(os.path.join(args.out, 'args'), 'w') as f:
    f.write(str(args))
if args.stdout:
    logging.basicConfig(format='%(asctime)s: %(message)s',
                        datefmt="%m-%d %H:%M:%S", level=logging.INFO)
else:
    logging.basicConfig(format='%(asctime)s: %(message)s',
                        datefmt="%m-%d %H:%M:%S",
                        filename=os.path.join(args.out, 'train.log'),
                        level=logging.INFO)

context = mx.gpu(0)
# Dataset
trainset = SequentialLoader('train', args.batch_size, context)
devset = SequentialLoader('dev', args.batch_size, context)

###############################################################################
# Build the model
###############################################################################

model = Transducer(40, 128, 2, args.dropout, bidirectional=args.bi)
# model.collect_params().initialize(mx.init.Xavier(), ctx=context)
if args.init:
    model.collect_params().load(args.init, context)
elif args.initam or args.initpm:
    model.initialize(mx.init.Uniform(0.1), ctx=context)
    if args.initam:
        model.collect_params('transducer0_rnnmodel0').load(args.initam, context, True, True)
    if args.initpm:
        model.collect_params('transducer0_rnnmodel1').load(args.initpm, context, True, True)
    # model.collect_params().save(args.out+'/init')
    # print('initial model save to', args.out+'/init')
else:
    model.initialize(mx.init.Uniform(0.1), ctx=context)

trainer = gluon.Trainer(model.collect_params(), 'sgd',
                        {'learning_rate': args.lr, 'momentum': 0.9})  # optimizer params assumed; the original snippet is truncated here
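Example #3 stops at the Trainer construction. For context, the usual Gluon update loop records the forward pass under autograd and then steps by batch size; a minimal self-contained sketch with a toy network and data, not the Transducer above:

import mxnet as mx
from mxnet import autograd, gluon

net = gluon.nn.Dense(2)
net.initialize(mx.init.Uniform(0.1))
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.01, 'momentum': 0.9})
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

x = mx.nd.random.randn(4, 8)
y = mx.nd.array([0, 1, 0, 1])
with autograd.record():          # record the forward pass for backprop
    loss = loss_fn(net(x), y)
loss.backward()
trainer.step(batch_size=4)       # apply the SGD update, scaled by batch size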
Example #4
def main(_):

    logging.info('Running with parameters:')
    logging.info(json.dumps(FLAGS.flag_values_dict(), indent=4))

    if os.path.exists(os.path.join(FLAGS.model_dir, 'config.json')):

        expect_partial = False
        if FLAGS.mode in ['transcribe-file', 'realtime']:
            expect_partial = True

        model = load_model(FLAGS.model_dir,
            checkpoint=FLAGS.checkpoint, expect_partial=expect_partial)

    else:

        if FLAGS.mode in ['eval', 'transcribe-file', 'realtime']:
            raise Exception('Model not found at path: {}'.format(
                FLAGS.model_dir))

        logging.info('Initializing model from scratch.')

        os.makedirs(FLAGS.model_dir, exist_ok=True)
        model_config_filepath = os.path.join(FLAGS.model_dir, 'config.json')

        vocab = vocabulary.init_vocab()
        vocabulary.save_vocab(vocab, os.path.join(FLAGS.model_dir, 'vocab'))

        model = Transducer(vocab=vocab,
                           encoder_layers=FLAGS.encoder_layers,
                           encoder_size=FLAGS.encoder_size,
                           pred_net_layers=FLAGS.pred_net_layers,
                           pred_net_size=FLAGS.pred_net_size,
                           joint_net_size=FLAGS.joint_net_size,
                           softmax_size=FLAGS.softmax_size)

        model.save_config(model_config_filepath)

        logging.info('Initialized model from scratch.')

    distribution_strategy = None

    if FLAGS.tpu is not None:

        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        # Connect to and initialize the TPU system before building the strategy.
        tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
        tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
        distribution_strategy = tf.distribute.experimental.TPUStrategy(
            tpu_cluster_resolver=tpu_cluster_resolver)

    if FLAGS.mode == 'export':
        
        saved_model_dir = os.path.join(FLAGS.model_dir, 'saved_model')
        os.makedirs(saved_model_dir, exist_ok=True)
        
        all_versions = [int(ver) for ver in os.listdir(saved_model_dir)]

        if len(all_versions) > 0:
            version = max(all_versions) + 1
        else:
            version = 1

        export_path = os.path.join(saved_model_dir, str(version))
        os.makedirs(export_path)

        tf.saved_model.save(model, export_path, signatures={
            'serving_default': model.predict
        })

    elif FLAGS.mode == 'transcribe-file':

        transcription = transcribe_file(model, FLAGS.input)

        print('Input file: {}'.format(FLAGS.input))
        print('Transcription: {}'.format(transcription))

    elif FLAGS.mode == 'realtime':

        import pyaudio  # used below but missing from the original snippet

        audio_buf = []
        last_result = None

        def stream_callback(in_data, frame_count, time_info, status):
            audio_buf.append(in_data)
            return None, pyaudio.paContinue

        def audio_gen():
            while True:
                if len(audio_buf) > 0:
                    audio_data = audio_buf.pop(0)  # pop, so each chunk is yielded once
                    audio_arr = np.frombuffer(audio_data, dtype=np.float32)
                    yield audio_arr

        FORMAT = pyaudio.paFloat32
        CHANNELS = 1
        RATE = 16000
        CHUNK = 2048

        audio = pyaudio.PyAudio()
        stream = audio.open(format=FORMAT,
                            channels=CHANNELS,
                            rate=RATE,
                            input=True,
                            frames_per_buffer=CHUNK,
                            stream_callback=stream_callback)
        
        stream.start_stream()

        outputs = transcribe_stream(model, audio_gen(), RATE)

        print('Transcribing live audio (press CTRL+C to stop)...')

        for (output, is_final) in outputs:
            if output != last_result and output != '' and not is_final:
                print('Partial Result: {}'.format(output))
                last_result = output
            if is_final:
                print('# Final Result: {}'.format(output))
                last_result = None

    else:

        if FLAGS.dataset_name == 'common-voice':
            data_utils = utils.data.common_voice
        else:
            raise ValueError('Unsupported dataset: {}'.format(FLAGS.dataset_name))

        train_dataset, dev_dataset = data_utils.create_datasets(FLAGS.dataset_path,
            max_data=FLAGS.max_data)

        if dev_dataset is None:
            dev_dataset = train_dataset.take(FLAGS.eval_size)
            train_dataset = train_dataset.skip(FLAGS.eval_size)

        if FLAGS.mode == 'eval':

            logging.info('Begin evaluation...')

            loss, acc = do_eval(model, dev_dataset,
                                batch_size=FLAGS.batch_size,
                                shuffle_buffer_size=FLAGS.shuffle_buffer_size,
                                distribution_strategy=distribution_strategy)

            logging.info('Evaluation complete: Loss {} Accuracy {}'.format(
                loss, acc))

        else:

            optimizer = tf.keras.optimizers.Adam(FLAGS.learning_rate)

            checkpoints_path = os.path.join(FLAGS.model_dir, 'checkpoints')
            os.makedirs(checkpoints_path, exist_ok=True)

            do_train(model, train_dataset, optimizer,
                     FLAGS.epochs, FLAGS.batch_size,
                     eval_dataset=dev_dataset,
                     steps_per_checkpoint=FLAGS.steps_per_checkpoint,
                     checkpoint_path=checkpoints_path,
                     steps_per_log=FLAGS.steps_per_log,
                     tb_log_dir=FLAGS.tb_log_dir,
                     keep_top_n=FLAGS.keep_top,
                     shuffle_buffer_size=FLAGS.shuffle_buffer_size,
                     distribution_strategy=distribution_strategy)
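After the export branch above writes a versioned SavedModel, it can be loaded back and called through its serving signature. A short sketch; the path is illustrative and follows the saved_model/<version> layout created above:

import tensorflow as tf

loaded = tf.saved_model.load('model_dir/saved_model/1')
serving_fn = loaded.signatures['serving_default']  # bound to model.predict at export time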
Example #5
import numpy as np
import torch

from data import Labels, split_train_dev_test
from model import Transducer
from utils import AverageMeter
from warp_rnnt import rnnt_loss

torch.backends.cudnn.benchmark = True
torch.manual_seed(1)
np.random.seed(1)

labels = Labels()

blank = torch.tensor([labels.blank()], dtype=torch.int).cuda()
space = torch.tensor([labels.space()], dtype=torch.int).cuda()

model = Transducer(128, len(labels), 512, 256, am_layers=3, lm_layers=3, dropout=0.3,
                   am_checkpoint='runs/Dec14_21-53-45_ctc_bs32x4_gn200/model20.bin',
                   lm_checkpoint='runs/Dec09_22-04-47_lm_bptt8_bs64_gn1_do0.3/model10.bin')
model.cuda()

train, dev, test = split_train_dev_test(
    '/open-stt-e2e/data/',
    labels, model.am.conv, batch_size=32
)

parameters = [
    {'params': model.fc.parameters(), 'lr': 3e-5},
    {'params': model.am.parameters(), 'lr': 3e-5},
    {'params': model.lm.parameters(), 'lr': 3e-5}
]

optimizer = torch.optim.Adam(parameters, weight_decay=1e-5)
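The parameters list above uses torch.optim's per-parameter-group options: each dict carries its own lr, while the weight_decay passed to the constructor applies to every group that does not override it. The same pattern in a self-contained toy form:

import torch

head = torch.nn.Linear(4, 2)
body = torch.nn.Linear(8, 4)
optimizer = torch.optim.Adam([
    {'params': head.parameters(), 'lr': 3e-5},
    {'params': body.parameters(), 'lr': 1e-4},   # per-group learning rate
], weight_decay=1e-5)                            # shared default across groups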
Example #6
def main(_):

    logging.info('Running with parameters:')
    logging.info(json.dumps(FLAGS.flag_values_dict(), indent=4))

    if os.path.exists(os.path.join(FLAGS.model_dir, 'config.json')):

        expect_partial = False
        if FLAGS.mode in ['transcribe-file', 'realtime', 'export']:
            expect_partial = True

        model = load_model(FLAGS.model_dir,
            checkpoint=FLAGS.checkpoint, expect_partial=expect_partial)

    else:

        if FLAGS.mode in ['eval', 'transcribe-file', 'realtime']:
            raise Exception('Model not found at path: {}'.format(
                FLAGS.model_dir))

        logging.info('Initializing model from scratch.')

        os.makedirs(FLAGS.model_dir, exist_ok=True)
        model_config_filepath = os.path.join(FLAGS.model_dir, 'config.json')

        vocab = vocabulary.init_vocab()
        vocabulary.save_vocab(vocab, os.path.join(FLAGS.model_dir, 'vocab'))

        model = Transducer(vocab=vocab,
                           encoder_layers=FLAGS.encoder_layers,
                           encoder_size=FLAGS.encoder_size,
                           pred_net_layers=FLAGS.pred_net_layers,
                           joint_net_size=FLAGS.joint_net_size,
                           softmax_size=FLAGS.softmax_size)

        model.save_config(model_config_filepath)

        logging.info('Initialized model from scratch.')

    distribution_strategy = None

    if FLAGS.tpu is not None:

        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        # Connect to and initialize the TPU system before building the strategy.
        tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
        tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
        distribution_strategy = tf.distribute.experimental.TPUStrategy(
            tpu_cluster_resolver=tpu_cluster_resolver)

    if FLAGS.mode == 'export':
        
        # saved_model_dir = os.path.join(FLAGS.model_dir, 'saved_model')
        # os.makedirs(saved_model_dir, exist_ok=True)
        
        # all_versions = [int(ver) for ver in os.listdir(saved_model_dir)]

        # if len(all_versions) > 0:
        #     version = max(all_versions) + 1
        # else:
        #     version = 1

        # export_path = os.path.join(saved_model_dir, str(version))
        # os.makedirs(export_path)

        # tf.saved_model.save(model, export_path, signatures={
        #     'serving_default': model.predict
        # })

        # print(model.predict(tf.zeros((1, 1024)), tf.constant([16000]), tf.constant(['hell']), tf.zeros((1, 2, 1, 2048))))

        tflite_dir = os.path.join(FLAGS.model_dir, 'lite')
        os.makedirs(tflite_dir, exist_ok=True)

        concrete_func = model.predict.get_concrete_function(
            audio=tf.TensorSpec([1, 1024], dtype=tf.float32),
            sr=tf.TensorSpec([1], dtype=tf.int32),
            pred_inp=tf.TensorSpec([1], dtype=tf.string),
            enc_state=tf.TensorSpec([1, 2, 1, model.encoder_size], dtype=tf.float32))

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                               tf.lite.OpsSet.SELECT_TF_OPS]
        converter.experimental_new_converter = True
        converter.experimental_new_quantizer = True
        converter.allow_custom_ops = True

        # def representative_dataset_gen():
        #     dataset, _ = load_datasets()
        #     for i in range(10):
        #         yield [next(dataset)]

        # converter.optimizations = [tf.lite.Optimize.DEFAULT]
        # converter.representative_dataset = representative_dataset_gen
        # converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        # converter.inference_input_type = tf.uint8
        # converter.inference_output_type = tf.uint8

        tflite_quant_model = converter.convert()
        
        with open(os.path.join(tflite_dir, 'model.tflite'), 'wb') as f:
            f.write(tflite_quant_model)

        print('Exported model to TFLite.')

    elif FLAGS.mode == 'transcribe-file':

        transcription = transcribe_file(model, FLAGS.input)

        print('Input file: {}'.format(FLAGS.input))
        print('Transcription: {}'.format(transcription))

    elif FLAGS.mode == 'realtime':

        import pyaudio

        audio_buf = []
        last_result = None

        def stream_callback(in_data, frame_count, time_info, status):
            audio_buf.append(in_data)
            return None, pyaudio.paContinue

        def audio_gen():
            while True:
                if len(audio_buf) > 0:
                    audio_data = audio_buf.pop(0)  # pop, so each chunk is yielded once
                    audio_arr = np.frombuffer(audio_data, dtype=np.float32)
                    yield audio_arr

        FORMAT = pyaudio.paFloat32
        CHANNELS = 1
        RATE = 16000
        CHUNK = 2048

        audio = pyaudio.PyAudio()
        stream = audio.open(format=FORMAT,
                            channels=CHANNELS,
                            rate=RATE,
                            input=True,
                            frames_per_buffer=CHUNK,
                            stream_callback=stream_callback)
        
        stream.start_stream()

        outputs = transcribe_stream(model, audio_gen(), RATE)

        print('Transcribing live audio (press CTRL+C to stop)...')

        for (output, is_final) in outputs:
            if output != last_result and output != '' and not is_final:
                print('Partial Result: {}'.format(output))
                last_result = output
            if is_final:
                print('# Final Result: {}'.format(output))
                last_result = None

    else:

        train_dataset, dev_dataset = load_datasets()

        if dev_dataset is None:
            dev_dataset = train_dataset.take(FLAGS.eval_size)
            train_dataset = train_dataset.skip(FLAGS.eval_size)

        if FLAGS.eval_size:
            dev_dataset = dev_dataset.take(FLAGS.eval_size)

        if FLAGS.mode == 'eval':

            logging.info('Begin evaluation...')

            loss, acc = do_eval(model, dev_dataset,
                                batch_size=FLAGS.batch_size,
                                shuffle_buffer_size=FLAGS.shuffle_buffer_size,
                                distribution_strategy=distribution_strategy)

            logging.info('Evaluation complete: Loss {} Accuracy {}'.format(
                loss, acc))

        else:

            optimizer = tf.keras.optimizers.Adam(FLAGS.learning_rate)

            checkpoints_path = os.path.join(FLAGS.model_dir, 'checkpoints')
            os.makedirs(checkpoints_path, exist_ok=True)

            do_train(model, train_dataset, optimizer,
                     FLAGS.epochs, FLAGS.batch_size,
                     eval_dataset=dev_dataset,
                     steps_per_checkpoint=FLAGS.steps_per_checkpoint,
                     checkpoint_path=checkpoints_path,
                     steps_per_log=FLAGS.steps_per_log,
                     tb_log_dir=FLAGS.tb_log_dir,
                     keep_top_n=FLAGS.keep_top,
                     shuffle_buffer_size=FLAGS.shuffle_buffer_size,
                     distribution_strategy=distribution_strategy)
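Once model.tflite is written, it can be inspected and driven with the TFLite Interpreter. A minimal sketch; the path is illustrative, and because the converter kept SELECT_TF_OPS the runtime must include the Flex delegate:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='model_dir/lite/model.tflite')
interpreter.allocate_tensors()

# Inspect the four inputs declared by the concrete function above
# (audio, sr, pred_inp, enc_state) before calling set_tensor()/invoke().
for detail in interpreter.get_input_details():
    print(detail['name'], detail['shape'], detail['dtype'])

output_details = interpreter.get_output_details()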
Example #7
import numpy as np
import torch

from data import Labels, split_train_dev_test  # assumed: same helpers as in Example #5
from model import Transducer
from utils import AverageMeter
from warp_rnnt import rnnt_loss

torch.backends.cudnn.benchmark = True
torch.manual_seed(2)
np.random.seed(2)

labels = Labels()

blank = torch.tensor([labels.blank()], dtype=torch.int).cuda()
space = torch.tensor([labels.space()], dtype=torch.int).cuda()

model_path = 'runs/Jan06_19-51-52_rnnt_bs32x4_gn200_beta0.5/model10.bin'

model = Transducer(128, len(labels), 512, 256, am_layers=3, lm_layers=3, dropout=0.3)
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.cuda()

train, dev, test = split_train_dev_test(
    '/open-stt-e2e/data/',
    labels, model.am.conv, batch_size=16
)

parameters = [
    {'params': model.fc.parameters(), 'lr': 3e-6},
    {'params': model.am.parameters(), 'lr': 3e-6},
    {'params': model.lm.parameters(), 'lr': 3e-6}
]

optimizer = torch.optim.Adam(parameters, weight_decay=1e-5)
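Example #7 loads the checkpoint onto the CPU first and only then moves the model to the GPU, which keeps the load independent of the device the checkpoint was saved from. The pattern in miniature:

import torch

model = torch.nn.Linear(4, 2)
torch.save(model.state_dict(), 'model.bin')           # weights only, no code

state = torch.load('model.bin', map_location='cpu')   # safe on any machine
model.load_state_dict(state)
model.cuda()                                          # move after loading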
Example #8
def main(cmd_args):
    """Run the main training function."""

    parser = get_parser()
    args, _ = parser.parse_known_args(cmd_args)

    # logging info
    if args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # set random seed
    logging.info("random seed = %d" % args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # load dictionary for debug log
    if args.dict is not None:
        with open(args.dict, "rb") as f:
            dictionary = f.readlines()
        char_list = [
            entry.decode("utf-8").split(" ")[0] for entry in dictionary
        ]
        char_list.insert(0, "<blank>")
        char_list.append("<eos>")
        args.char_list = char_list
    else:
        args.char_list = None

    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["input"][0]["shape"][-1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])

    logging.info("input dims: " + str(idim))
    logging.info("#output dims: " + str(odim))

    # data
    Data = Dataloader(args)

    # model
    Model = Transducer(idim, odim, args)

    # update saved model
    call_back = ModelCheckpoint(monitor='val_loss', dirpath=args.outdir)

    # train model
    trainer = Trainer(gpus=args.ngpu,
                      callbacks=[call_back],
                      max_epochs=args.epochs,
                      resume_from_checkpoint=args.resume)
    trainer.fit(Model, Data)
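trainer.fit(Model, Data) requires Transducer to be a pytorch_lightning.LightningModule (and Dataloader a compatible data source). A minimal module skeleton showing the hooks involved, with a toy layer and loss rather than the real model:

import pytorch_lightning as pl
import torch
import torch.nn.functional as F

class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.layer(x), y)
        self.log('val_loss', loss)  # the ModelCheckpoint above monitors this key
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())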
Example #9
def main(args):
    """Run the main decoding function."""
    parser = get_parser()
    args = parser.parse_args(args)

    # logging info
    if args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    with open(args.recog_json, "rb") as f:
        test_json = json.load(f)["utts"]
    utts = list(test_json.keys())
    idim = int(test_json[utts[0]]["input"][0]["shape"][-1])
    odim = int(test_json[utts[0]]["output"][0]["shape"][-1])

    load_test = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=None,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
    )

    if args.dict is not None:
        with open(args.dict, "rb") as f:
            dictionary = f.readlines()
        char_list = [
            entry.decode("utf-8").split(" ")[0] for entry in dictionary
        ]
        char_list.insert(0, "<blank>")
        char_list.append("<eos>")
        args.char_list = char_list
    else:
        args.char_list = None

    Model = Transducer.load_from_checkpoint(args.model,
                                            idim=idim,
                                            odim=odim,
                                            args=args)

    new_js = {}
    with torch.no_grad():
        for idx, name in enumerate(test_json.keys(), 1):
            logging.info("(%d/%d) decoding " + name, idx,
                         len(test_json.keys()))
            batch = [(name, test_json[name])]
            feat = load_test(batch)[0][0]

            nbest_hyps = Model.recog(feat)

            new_js[name] = add_results_to_json(test_json[name], nbest_hyps,
                                               args.char_list)

    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps({
                "utts": new_js
            },
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode("utf_8"))