Esempio n. 1
0
def recog(args):
    '''Run recognition.

    Loads the trained E2E model (and optionally a character RNNLM and/or a
    word RNNLM), decodes every utterance listed in ``args.recog_json`` with
    ``E2E.recognize`` and writes the results, as built by
    ``add_results_to_json``, to ``args.result_label`` as UTF-8 JSON.

    :param argparse.Namespace args: recognition options. Reads ``seed``,
        ``model``, ``model_conf``, ``rnnlm``, ``rnnlm_conf``, ``word_rnnlm``,
        ``word_dict``, ``recog_json`` and ``result_label``; the whole
        namespace is also forwarded to ``E2E.recognize``.
    '''
    # seed setting
    torch.manual_seed(args.seed)

    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # load trained model parameters
    logging.info('reading model parameters from ' + args.model)
    e2e = E2E(idim, odim, train_args)
    # parameters are loaded into the Loss wrapper (presumably the object the
    # checkpoint was saved from); its predictor is ``e2e``, used below
    model = Loss(e2e, train_args.mtlalpha)
    torch_load(args.model, model)

    # read rnnlm
    if args.rnnlm:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(train_args.char_list), rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        rnnlm.eval()
    else:
        rnnlm = None

    if args.word_rnnlm:
        if not args.word_dict:
            logging.error(
                'word dictionary file is not specified for the word RNNLM.')
            sys.exit(1)

        rnnlm_args = get_model_conf(args.word_rnnlm, args.rnnlm_conf)
        word_dict = load_labeldict(args.word_dict)
        char_dict = {x: i for i, x in enumerate(train_args.char_list)}
        word_rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(word_dict), rnnlm_args.unit))
        torch_load(args.word_rnnlm, word_rnnlm)
        word_rnnlm.eval()

        # combine with the character RNNLM when one was loaded; otherwise
        # wrap the word RNNLM alone
        if rnnlm is not None:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.MultiLevelLM(word_rnnlm.predictor,
                                           rnnlm.predictor, word_dict,
                                           char_dict))
        else:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.LookAheadWordLM(word_rnnlm.predictor, word_dict,
                                              char_dict))

    # read json data
    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']

    # decode each utterance
    new_js = {}
    num_utts = len(js)
    with torch.no_grad():
        for idx, name in enumerate(js, 1):
            # pass the name as a logging argument instead of concatenating it
            # into the format string: a '%' in an utterance id would break
            # the %-formatting of the record
            logging.info('(%d/%d) decoding %s', idx, num_utts, name)
            feat = kaldi_io_py.read_mat(js[name]['input'][0]['feat'])
            nbest_hyps = e2e.recognize(feat, args, train_args.char_list, rnnlm)
            new_js[name] = add_results_to_json(js[name], nbest_hyps,
                                               train_args.char_list)

    # TODO(watanabe) fix character coding problems when saving it
    with open(args.result_label, 'wb') as f:
        f.write(
            json.dumps({
                'utts': new_js
            }, indent=4, sort_keys=True).encode('utf_8'))
def train(args):
    '''Run joint ASR-TTS training.

    Builds an ASR model (``E2E`` wrapped in ``Loss``) and a TTS loss from
    ``setup_tts_loss``, combines them into ``ASRTTSLoss``, creates one
    optimizer per objective ('asr', 'tts', 's2s', 't2t', 'mmd') and runs a
    chainer ``Trainer`` driven by ``CustomUpdater``.

    :param argparse.Namespace args: training options. This function mutates
        ``args``: it sets ``tts_spk_embed_dim`` and, with more than one GPU,
        multiplies ``batch_size`` by ``ngpu``.
    '''
    # seed setting
    torch.manual_seed(args.seed)

    # debug mode setting
    # 0 would be fastest, but 1 seems to be reasonable
    # by considering reproducibility
    # remove type check
    if args.debugmode < 2:
        chainer.config.type_check = False
        logging.info('torch type check is disabled')
    # use deterministic computation or not
    if args.debugmode < 1:
        torch.backends.cudnn.deterministic = False
        logging.info('torch cudnn deterministic is disabled')
    else:
        torch.backends.cudnn.deterministic = True

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info from the first validation utterance
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    first_utt = valid_json[list(valid_json.keys())[0]]
    idim_asr = int(first_utt['input'][0]['shape'][1])
    idim_tts = int(first_utt['input'][1]['shape'][1])
    odim_asr = int(first_utt['output'][0]['shape'][1])
    # input[1] (used below as the TTS output dim) must have exactly 3 fewer
    # dims than the ASR input features in input[0]
    assert idim_tts == idim_asr - 3
    logging.info('#input dims for ASR: ' + str(idim_asr))
    logging.info('#input dims for TTS: ' + str(idim_tts))
    logging.info('#output dims: ' + str(odim_asr))
    if args.tts_use_speaker_embedding:
        args.tts_spk_embed_dim = int(first_utt['input'][2]['shape'][0])
    else:
        args.tts_spk_embed_dim = None

    # report attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        logging.info('Pure attention mode')
    else:
        logging.info('Multitask learning mode')

    # specify model architecture for ASR
    e2e_asr = E2E(idim_asr, odim_asr, args)
    logging.info(e2e_asr)
    asr_loss = Loss(e2e_asr, args.mtlalpha)

    # specify model architecture for TTS
    # reverse input and output dimension
    tts_loss = setup_tts_loss(odim_asr, idim_tts, args)
    logging.info(tts_loss)

    # define loss
    model = ASRTTSLoss(asr_loss, tts_loss, args)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.conf'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        # TODO(watanabe) use others than pickle, possibly json, and save as a text
        pickle.dump((idim_asr, odim_asr, args), f)
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # Set gpu
    ngpu = args.ngpu
    if ngpu == 1:
        gpu_id = range(ngpu)
        logging.info('gpu id: ' + str(gpu_id))
        model.cuda()
    elif ngpu > 1:
        gpu_id = range(ngpu)
        logging.info('gpu id: ' + str(gpu_id))
        model = torch.nn.DataParallel(model, device_ids=gpu_id)
        model.cuda()
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu
    else:
        gpu_id = [-1]

    # With multiple GPUs ``model`` is a DataParallel wrapper that does not
    # expose the sub-losses; ``module`` is always the bare ASRTTSLoss.
    module = model.module if ngpu > 1 else model

    # Setup one optimizer per objective; they share lr/eps/weight_decay and
    # differ only in the per-objective lr scale
    dummy_target = chainer.Chain()
    optim_class = getattr(torch.optim, args.optim)

    def make_optimizer(params, lr_scale):
        return optim_class(params,
                           args.lr * lr_scale,
                           eps=args.eps,
                           weight_decay=args.weight_decay)

    # the MMD objective updates both autoencoders; set() drops any parameter
    # that appears in both
    ae_param = list(
        set(
            list(module.ae_speech.parameters()) +
            list(module.ae_text.parameters())))
    opts = {
        'asr': make_optimizer(module.asr_loss.parameters(), args.asr_weight),
        'tts': make_optimizer(module.tts_loss.parameters(), args.tts_weight),
        's2s': make_optimizer(module.ae_speech.parameters(), args.s2s_weight),
        't2t': make_optimizer(module.ae_text.parameters(), args.t2t_weight),
        'mmd': make_optimizer(ae_param, args.mmd_weight),
    }
    for opt in opts.values():
        # FIXME: TOO DIRTY HACK -- give the torch optimizers chainer-style
        # ``target``/``serialize`` attributes backed by a dummy chain
        # (presumably needed by chainer's snapshot/serialization code)
        setattr(opt, "target", dummy_target)
        setattr(opt, "serialize", lambda s: dummy_target.serialize(s))

    # read json data (valid_json was loaded above and is still unmodified)
    logging.warning("reading json")
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']

    # read utt2mode scp
    logging.warning("reading utt2mode")

    def load_utt2mode(scp, utts):
        '''Set ``utts[k]['utt2mode'] = v`` for each "k v" line of ``scp``.'''
        with open(scp, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines
                k, v = fields
                if k in utts:
                    utts[k]['utt2mode'] = v

    load_utt2mode(args.train_utt2mode, train_json)
    load_utt2mode(args.valid_utt2mode, valid_json)

    # make minibatch list (variable length)
    _, train_multi = make_batchset_asrtts(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        half_unpair=args.use_mmd_autoencoding)
    valid, _ = make_batchset_asrtts(valid_json, args.batch_size,
                                    args.maxlen_in, args.maxlen_out,
                                    args.minibatches)
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    train_iter = chainer.iterators.SerialIterator(train_multi["pair"], 1)
    valid_iter = chainer.iterators.SerialIterator(valid,
                                                  1,
                                                  repeat=False,
                                                  shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model,
                            args.grad_clip,
                            train_iter,
                            train_multi,
                            opts,
                            converter=converter_kaldi,
                            device=gpu_id,
                            args=args)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)
        # NOTE(review): this function only saves 'model.loss.best'; confirm
        # that 'model.acc.best' is produced elsewhere before resuming
        module.load_state_dict(torch.load(args.outdir + '/model.acc.best'))

    # Load pretrained parameters; the sub-losses live on the bare module,
    # not on the DataParallel wrapper
    if args.model:
        module.load_state_dict(torch.load(args.model))
    elif args.model_asr:
        module.asr_loss.load_state_dict(torch.load(args.model_asr))
    elif args.model_tts:
        module.tts_loss.load_state_dict(torch.load(args.model_tts))
        # NOTE(review): kept from the original; presumably a no-op if the
        # updater stores the same model object it was given
        model = trainer.updater.model

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluater(model,
                        valid_iter,
                        dummy_target,
                        converter=converter_kaldi,
                        device=gpu_id))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        data = converter_kaldi([data],
                               device=gpu_id,
                               use_speaker_embedding=args.tts_spk_embed_dim)
        trainer.extend(CustomPlotAttentionReport(model, True, data,
                                                 args.outdir + "/att_ws_asr"),
                       trigger=(1, 'epoch'))
        trainer.extend(CustomPlotAttentionReport(model,
                                                 False,
                                                 data,
                                                 args.outdir + '/att_ws_tts',
                                                 reverse=True),
                       trigger=(1, 'epoch'))

    # Take a snapshot for each specified epoch
    trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    plot_keys = [
        't/loss', 't/tts_loss', 't/s2s_loss', 't/asr_loss', 't/t2t_loss',
        't/mmd_loss', 't/ae_mmd_loss', 'd/loss', 'd/tts_loss', 'd/s2s_loss',
        'd/asr_loss', 'd/t2t_loss'
    ]
    trainer.extend(
        extensions.PlotReport(plot_keys, 'epoch', file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['t/asr_acc', 'd/asr_acc'],
                              'epoch',
                              file_name='acc.png'))

    # Save best models (state dict plus a pickled copy of the module)
    def torch_save(path, _):
        if ngpu > 1:
            torch.save(model.module.state_dict(), path)
            torch.save(model.module, path + ".pkl")
        else:
            torch.save(model.state_dict(), path)
            torch.save(model, path + ".pkl")

    trainer.extend(extensions.snapshot_object(model,
                                              'model.loss.best',
                                              savefun=torch_save),
                   trigger=training.triggers.MinValueTrigger('d/loss'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL,
                                                 'iteration')))
    report_keys = [
        't/loss', 't/tts_loss', 't/s2s_loss', 't/asr_loss', 't/t2t_loss',
        't/asr_acc', 't/t2t_acc', 't/mmd_loss', 't/ae_mmd_loss', 'd/loss',
        'd/tts_loss', 'd/s2s_loss', 'd/asr_loss', 'd/t2t_loss', 'd/asr_acc',
        'd/t2t_acc', 'epoch', 'iteration', 'elapsed_time'
    ]
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))

    # Run the training
    trainer.run()
Esempio n. 3
0
def train(args):
    '''Run ASR training with a chainer ``Trainer``.

    Builds an ``E2E`` model wrapped in ``Loss``, writes ``model.json`` to
    ``args.outdir`` and trains with Adadelta or Adam. The best models are
    saved by validation loss, and also by validation accuracy unless the
    model is in pure CTC mode.

    :param argparse.Namespace args: training options. With more than one
        GPU, ``args.batch_size`` is multiplied by ``args.ngpu`` in place.
    :raises ValueError: if ``args.opt`` is neither 'adadelta' nor 'adam'.
    '''
    # seed setting
    torch.manual_seed(args.seed)

    # debug mode setting
    # 0 would be fastest, but 1 seems to be reasonable
    # by considering reproducibility
    # remove type check
    if args.debugmode < 2:
        chainer.config.type_check = False
        logging.info('torch type check is disabled')
    # use deterministic computation or not
    if args.debugmode < 1:
        torch.backends.cudnn.deterministic = False
        logging.info('torch cudnn deterministic is disabled')
    else:
        torch.backends.cudnn.deterministic = True

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info from the first validation utterance
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    e2e = E2E(idim, odim, args)
    model = Loss(e2e, args.mtlalpha)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)), indent=4,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # keep the reporter before a possible DataParallel wrap hides it
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters())
    else:
        # fail clearly instead of hitting an unbound ``optimizer`` below
        raise ValueError('unsupported optimizer: ' + str(args.opt))

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(e2e.subsample[0])

    # read json data (valid_json was loaded above and is still unmodified)
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']

    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size, args.maxlen_in,
                          args.maxlen_out, args.minibatches)
    valid = make_batchset(valid_json, args.batch_size, args.maxlen_in,
                          args.maxlen_out, args.minibatches)
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    train_iter = chainer.iterators.MultiprocessIterator(
        TransformDataset(train, converter.transform),
        batch_size=1,
        n_processes=1,
        n_prefetch=8)
    valid_iter = chainer.iterators.SerialIterator(
        TransformDataset(valid, converter.transform),
        batch_size=1,
        repeat=False,
        shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            converter, device, args.ngpu)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s', args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.predictor.calculate_all_attentions
        else:
            att_vis_fn = model.predictor.calculate_all_attentions
        trainer.extend(PlotAttentionReport(att_vis_fn,
                                           data,
                                           args.outdir + "/att_ws",
                                           converter=converter,
                                           device=device),
                       trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ], 'epoch', file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))

    # Save best models; accuracy is only tracked when there is an attention
    # decoder (compare strings with !=, never with ``is not``)
    trainer.extend(
        extensions.snapshot_object(model,
                                   'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(
            extensions.snapshot_object(model,
                                       'model.acc.best',
                                       savefun=torch_save),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer: whenever the monitored validation
    # metric is worse than its best value, restore the best saved model and
    # apply adadelta_eps_decay
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            monitor = ('validation/main/acc', '/model.acc.best',
                       lambda best_value, current_value:
                       best_value > current_value)
        elif args.criterion == 'loss':
            monitor = ('validation/main/loss', '/model.loss.best',
                       lambda best_value, current_value:
                       best_value < current_value)
        else:
            monitor = None
        if monitor is not None:
            monitor_key, best_name, got_worse = monitor
            # each extension needs its own (stateful) trigger instance
            trainer.extend(restore_snapshot(model,
                                            args.outdir + best_name,
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(monitor_key, got_worse))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(monitor_key, got_worse))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL,
                                                 'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').
            param_groups[0]["eps"]),
                       trigger=(REPORT_INTERVAL, 'iteration'))
        report_keys.append('eps')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))

    # Run the training
    trainer.run()
Esempio n. 4
0
def test_asrtts_model_trainable_and_decodable():
    """Smoke-test the ASR, TTS and joint ASR-TTS models.

    Runs a forward/backward pass and an eval-mode decode for the ASR model
    and the TTS model separately. Then it wraps both in ``ASRTTSLoss``,
    builds per-objective Adam optimizers, and back-propagates the
    speech-autoencoder loss, the text-autoencoder loss and the MMD loss
    between their hidden states.
    """
    from asrtts_pytorch import setup_tts_loss
    from test_e2e_tts import make_inference_args

    args = make_args()
    print(args)

    # ---- ASR: trainable, then decodable in eval mode ----
    asr_xpad, asr_xlens, asr_ypad = prepare_asr_inputs("pytorch")
    batch_size = asr_xpad.shape[0]
    asr_idim, asr_odim = 40, 5
    asr_model = Loss(E2E(asr_idim, asr_odim, args), args.mtlalpha)
    asr_model(asr_xpad, asr_xlens, asr_ypad).backward()
    asr_model.eval()
    with torch.no_grad():
        random_feats = np.random.randn(100, 40)
        asr_model.predictor.recognize(random_feats, args, args.char_list)
    asr_model.train()

    # ---- TTS (dims reversed w.r.t. ASR): trainable, then decodable ----
    tts_feat_dim = asr_idim - 3
    max_in_len = max_out_len = 10
    tts_batch = prepare_tts_inputs(batch_size, asr_odim, tts_feat_dim,
                                   max_in_len, max_out_len)
    tts_xpad, tts_ilens, tts_ypad, tts_labels, tts_olens = tts_batch
    args.tts_spk_embed_dim = 2
    spembs = torch.randn(batch_size, args.tts_spk_embed_dim)
    tts_model = setup_tts_loss(asr_odim, tts_feat_dim, args)
    tts_model(*tts_batch, spembs).backward()
    tts_model.eval()
    with torch.no_grad():
        first_input = tts_xpad[0][:tts_ilens[0]]
        inference_args = Namespace(**make_inference_args())
        _yhat, _probs, _att_ws = tts_model.model.inference(
            first_input, inference_args, spembs[0])
        tts_model.model.calculate_all_attentions(tts_xpad, tts_ilens,
                                                 tts_ypad, spembs)
    tts_model.train()

    # ---- joint ASR-TTS model: one Adam optimizer per objective ----
    model = ASRTTSLoss(asr_model, tts_model, args)

    def adam(params, lr):
        return torch.optim.Adam(params,
                                lr,
                                eps=args.eps,
                                weight_decay=args.weight_decay)

    shared_ae_params = list(
        set(list(model.ae_speech.parameters()) +
            list(model.ae_text.parameters())))
    opts = {
        'asr': adam(model.asr_loss.parameters(), args.lr * 0.1),
        'tts': adam(model.tts_loss.parameters(), args.lr),
        's2s': adam(model.ae_speech.parameters(), args.lr * 0.01),
        't2t': adam(model.ae_text.parameters(), args.lr * 0.01),
        'mmd': adam(shared_ae_params, args.lr * 0.01),
    }

    # ---- data prep: (utt_id, features) pairs with a dummy transcript ----
    dummy_tokenid = "1 2 3"
    asr_data = []
    for b in range(batch_size):
        entry = dict(
            feat_asr=asr_xpad[b, :asr_xlens[b]].numpy(),
            feat_tts=tts_ypad[b, :tts_olens[b]].numpy(),
            feat_spembs=spembs[0].numpy(),
            output=[dict(tokenid=dummy_tokenid, shape=[3, asr_odim])],
        )
        asr_data.append(("tmp", entry))

    # speech-to-speech autoencoder
    s2s_loss, hspad, hslen = model.ae_speech(asr_data, return_hidden=True)
    s2s_loss.backward(retain_graph=True)
    # text-to-text autoencoder
    t2t_loss, _t2t_acc, htpad, htlen = model.ae_text(asr_data,
                                                     return_hidden=True)
    t2t_loss.backward(retain_graph=True)
    # inter-domain (MMD) loss between the two hidden representations
    packed_mmd(hspad, hslen, htpad, htlen).backward()
Esempio n. 5
0
def recog(args):
    '''Run recognition and dump the results to a JSON file.

    Loads the trained E2E model (and optionally a character RNNLM and/or a
    word RNNLM), decodes every utterance in ``args.recog_json``, logs the
    reference and 1-best hypothesis, and writes a UTF-8 JSON file to
    ``args.result_label``. For each utterance that file holds:

    * ``utt2spk`` copied from the input;
    * ``output``: the original first output entry extended with
      ``rec_tokenid``/``rec_token``/``rec_text``;
    * when ``args.beam_size > 1`` and more than one hypothesis was returned,
      indexed n-best keys ``rec_tokenid[NNNNN]``, ``rec_token[NNNNN]``,
      ``rec_text[NNNNN]`` and ``score[NNNNN]``.

    :param argparse.Namespace args: recognition options (also forwarded to
        ``E2E.recognize``).
    '''
    # seed setting
    torch.manual_seed(args.seed)

    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # load trained model parameters
    logging.info('reading model parameters from ' + args.model)
    e2e = E2E(idim, odim, train_args)
    model = Loss(e2e, train_args.mtlalpha)
    torch_load(args.model, model)

    # read rnnlm
    if args.rnnlm:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(train_args.char_list), rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        rnnlm.eval()
    else:
        rnnlm = None

    if args.word_rnnlm:
        if not args.word_dict:
            logging.error('word dictionary file is not specified for the word RNNLM.')
            sys.exit(1)

        rnnlm_args = get_model_conf(args.word_rnnlm, args.rnnlm_conf)
        word_dict = load_labeldict(args.word_dict)
        char_dict = {x: i for i, x in enumerate(train_args.char_list)}
        word_rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(word_dict), rnnlm_args.unit))
        torch_load(args.word_rnnlm, word_rnnlm)
        word_rnnlm.eval()

        # combine with the character RNNLM when one was loaded; otherwise
        # wrap the word RNNLM alone
        if rnnlm is not None:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.MultiLevelLM(word_rnnlm.predictor,
                                           rnnlm.predictor, word_dict, char_dict))
        else:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.LookAheadWordLM(word_rnnlm.predictor,
                                              word_dict, char_dict))

    # read json data
    with open(args.recog_json, 'rb') as f:
        recog_json = json.load(f)['utts']

    char_list = train_args.char_list
    new_json = {}
    with torch.no_grad():
        for name, utt in recog_json.items():
            feat = kaldi_io_py.read_mat(utt['input'][0]['feat'])
            nbest_hyps = e2e.recognize(feat, args, char_list, rnnlm=rnnlm)
            # get 1best and remove sos
            y_hat = nbest_hyps[0]['yseq'][1:]
            y_true = [int(tok) for tok in utt['output'][0]['tokenid'].split()]

            # print out decoding result; the texts are passed as logging
            # arguments so a literal '%' in a transcript cannot break the
            # format string
            seq_hat = [char_list[int(idx)] for idx in y_hat]
            seq_true = [char_list[idx] for idx in y_true]
            seq_hat_text = "".join(seq_hat).replace('<space>', ' ')
            seq_true_text = "".join(seq_true).replace('<space>', ' ')
            logging.info("groundtruth[%s]: %s", name, seq_true_text)
            logging.info("prediction [%s]: %s", name, seq_hat_text)

            # copy old json info
            new_json[name] = {'utt2spk': utt['utt2spk']}

            # added recognition results to json (shallow copy of the
            # original first output entry)
            logging.debug("dump token id")
            out_dic = dict(utt['output'][0])

            # TODO(karita) make consistent to chainer as idx[0] not idx
            out_dic['rec_tokenid'] = " ".join(str(idx) for idx in y_hat)
            logging.debug("dump token")
            out_dic['rec_token'] = " ".join(seq_hat)
            logging.debug("dump text")
            out_dic['rec_text'] = seq_hat_text

            new_json[name]['output'] = [out_dic]
            # TODO(nelson): Modify this part when saving more than 1 hyp is enabled
            # add n-best recognition results with scores
            if args.beam_size > 1 and len(nbest_hyps) > 1:
                for i, hyp in enumerate(nbest_hyps):
                    suffix = '[{:05d}]'.format(i)
                    hyp_ids = hyp['yseq'][1:]
                    hyp_tokens = [char_list[int(idx)] for idx in hyp_ids]
                    new_json[name]['rec_tokenid' + suffix] = \
                        " ".join(str(idx) for idx in hyp_ids)
                    new_json[name]['rec_token' + suffix] = " ".join(hyp_tokens)
                    new_json[name]['rec_text' + suffix] = \
                        "".join(hyp_tokens).replace('<space>', ' ')
                    new_json[name]['score' + suffix] = hyp['score']

    # TODO(watanabe) fix character coding problems when saving it
    with open(args.result_label, 'wb') as f:
        f.write(json.dumps({'utts': new_json}, indent=4, sort_keys=True).encode('utf_8'))
Esempio n. 6
0
def recog(args):
    '''Run recognition'''
    # seed setting
    torch.manual_seed(args.seed)

    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # load trained model parameters
    logging.info('reading model parameters from ' + args.model)
    e2e = E2E(idim, odim, train_args)
    model = Loss(e2e, train_args.mtlalpha)
    torch_load(args.model, model)
    e2e.recog_args = args

    # read rnnlm
    if args.rnnlm:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(
                len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        rnnlm.eval()
    else:
        rnnlm = None

    if args.word_rnnlm:
        rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
        word_dict = rnnlm_args.char_list_dict
        char_dict = {x: i for i, x in enumerate(train_args.char_list)}
        word_rnnlm = lm_pytorch.ClassifierWithState(lm_pytorch.RNNLM(
            len(word_dict), rnnlm_args.layer, rnnlm_args.unit))
        torch_load(args.word_rnnlm, word_rnnlm)
        word_rnnlm.eval()

        if rnnlm is not None:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.MultiLevelLM(word_rnnlm.predictor,
                                           rnnlm.predictor, word_dict, char_dict))
        else:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.LookAheadWordLM(word_rnnlm.predictor,
                                              word_dict, char_dict))

    # gpu
    if args.ngpu == 1:
        gpu_id = range(args.ngpu)
        logging.info('gpu id: ' + str(gpu_id))
        model.cuda()
        if rnnlm:
            rnnlm.cuda()

    # read json data
    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']
    new_js = {}

    if args.batchsize is None:
        with torch.no_grad():
            for idx, name in enumerate(js.keys(), 1):
                logging.info('(%d/%d) decoding ' + name, idx, len(js.keys()))
                feat = kaldi_io_py.read_mat(js[name]['input'][0]['feat'])
                nbest_hyps = e2e.recognize(feat, args, train_args.char_list, rnnlm)
                new_js[name] = add_results_to_json(js[name], nbest_hyps, train_args.char_list)
    else:
        try:
            from itertools import zip_longest as zip_longest
        except Exception:
            from itertools import izip_longest as zip_longest

        def grouper(n, iterable, fillvalue=None):
            kargs = [iter(iterable)] * n
            return zip_longest(*kargs, fillvalue=fillvalue)

        # sort data
        keys = js.keys()
        feat_lens = [js[key]['input'][0]['shape'][0] for key in keys]
        sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
        keys = [keys[i] for i in sorted_index]

        with torch.no_grad():
            for names in grouper(args.batchsize, keys, None):
                names = [name for name in names if name]
                feats = [kaldi_io_py.read_mat(js[name]['input'][0]['feat'])
                         for name in names]
                nbest_hyps = e2e.recognize_batch(feats, args, train_args.char_list, rnnlm=rnnlm)
                for i, nbest_hyp in enumerate(nbest_hyps):
                    name = names[i]
                    new_js[name] = add_results_to_json(js[name], nbest_hyp, train_args.char_list)

    # TODO(watanabe) fix character coding problems when saving it
    with open(args.result_label, 'wb') as f:
        f.write(json.dumps({'utts': new_js}, indent=4, sort_keys=True).encode('utf_8'))
def recog(args):
    '''Run recognition'''

    # rnnlm
    import extlm_pytorch
    import lm_pytorch

    # seed setting
    torch.manual_seed(args.seed)

    # read training config
    with open(args.model_conf, "rb") as f:
        logging.info('reading a model config file from' + args.model_conf)
        idim_asr, odim_asr, train_args = pickle.load(f)

    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # specify model architecture
    logging.info('reading model parameters from' + args.model)
    e2e_asr = E2E(idim_asr, odim_asr, train_args)
    logging.info(e2e_asr)
    asr_loss = Loss(e2e_asr, train_args.mtlalpha)

    # specify model architecture for TTS
    # reverse input and output dimension
    tts_loss = setup_tts_loss(odim_asr, idim_asr - 3, train_args)
    logging.info(tts_loss)

    # define loss
    model = ASRTTSLoss(asr_loss, tts_loss, train_args)

    def cpu_loader(storage, location):
        return storage

    def remove_dataparallel(state_dict):
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if k.startswith("module."):
                k = k[7:]
            new_state_dict[k] = v
        return new_state_dict

    model.load_state_dict(
        remove_dataparallel(torch.load(args.model, map_location=cpu_loader)))

    # read rnnlm
    if args.rnnlm:
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(train_args.char_list), 650))
        rnnlm.load_state_dict(torch.load(args.rnnlm, map_location=cpu_loader))
        rnnlm.eval()
    else:
        rnnlm = None

    if args.word_rnnlm:
        if not args.word_dict:
            logging.error(
                'word dictionary file is not specified for the word RNNLM.')
            sys.exit(1)

        word_dict = load_labeldict(args.word_dict)
        char_dict = {x: i for i, x in enumerate(train_args.char_list)}
        word_rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(word_dict), 650))
        word_rnnlm.load_state_dict(
            torch.load(args.word_rnnlm, map_location=cpu_loader))
        word_rnnlm.eval()

        if rnnlm is not None:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.MultiLevelLM(word_rnnlm.predictor,
                                           rnnlm.predictor, word_dict,
                                           char_dict))
        else:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.LookAheadWordLM(word_rnnlm.predictor, word_dict,
                                              char_dict))

    # read json data
    with open(args.recog_json, 'rb') as f:
        recog_json = json.load(f)['utts']

    new_json = {}
    for name in recog_json.keys():
        feat = kaldi_io_py.read_mat(recog_json[name]['input'][0]['feat'])
        nbest_hyps = e2e_asr.recognize(feat,
                                       args,
                                       train_args.char_list,
                                       rnnlm=rnnlm)
        # get 1best and remove sos
        y_hat = nbest_hyps[0]['yseq'][1:]
        y_true = map(int, recog_json[name]['output'][0]['tokenid'].split())

        # print out decoding result
        seq_hat = [train_args.char_list[int(idx)] for idx in y_hat]
        seq_true = [train_args.char_list[int(idx)] for idx in y_true]
        seq_hat_text = "".join(seq_hat).replace('<space>', ' ')
        seq_true_text = "".join(seq_true).replace('<space>', ' ')
        logging.info("groundtruth[%s]: " + seq_true_text, name)
        logging.info("prediction [%s]: " + seq_hat_text, name)

        # copy old json info
        new_json[name] = dict()
        new_json[name]['utt2spk'] = recog_json[name]['utt2spk']

        # added recognition results to json
        logging.debug("dump token id")
        out_dic = dict()
        for _key in recog_json[name]['output'][0]:
            out_dic[_key] = recog_json[name]['output'][0][_key]

        # TODO(karita) make consistent to chainer as idx[0] not idx
        out_dic['rec_tokenid'] = " ".join([str(idx) for idx in y_hat])
        logging.debug("dump token")
        out_dic['rec_token'] = " ".join(seq_hat)
        logging.debug("dump text")
        out_dic['rec_text'] = seq_hat_text

        new_json[name]['output'] = [out_dic]
        # TODO(nelson): Modify this part when saving more than 1 hyp is enabled
        # add n-best recognition results with scores
        if args.beam_size > 1 and len(nbest_hyps) > 1:
            for i, hyp in enumerate(nbest_hyps):
                y_hat = hyp['yseq'][1:]
                seq_hat = [train_args.char_list[int(idx)] for idx in y_hat]
                seq_hat_text = "".join(seq_hat).replace('<space>', ' ')
                new_json[name]['rec_tokenid' + '[' + '{:05d}'.format(i) +
                               ']'] = " ".join([str(idx) for idx in y_hat])
                new_json[name]['rec_token' + '[' + '{:05d}'.format(i) +
                               ']'] = " ".join(seq_hat)
                new_json[name]['rec_text' + '[' + '{:05d}'.format(i) +
                               ']'] = seq_hat_text
                new_json[name]['score' + '[' + '{:05d}'.format(i) +
                               ']'] = hyp['score']

    # TODO(watanabe) fix character coding problems when saving it
    with open(args.result_label, 'wb') as f:
        f.write(
            json.dumps({
                'utts': new_json
            }, indent=4, sort_keys=True).encode('utf_8'))
def tts_decode(args):
    '''RUN DECODING

    Synthesizes features with Tacotron2 for every utterance in ``args.json``
    (using its 'output'[0]['tokenid'] sequence plus an appended eos id) and
    writes them through Kaldi's ``copy-feats`` to ``<args.out>.ark`` and
    ``<args.out>.scp``.
    '''
    # seed setting
    torch.manual_seed(args.seed)

    # show arguments
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # read training config
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # config files from a trusted source.
    with open(args.model_conf, "rb") as f:
        logging.info('reading a model config file from ' + args.model_conf)
        idim_asr, odim_asr, train_args = pickle.load(f)

    # specify model architecture
    logging.info('reading model parameters from ' + args.model)
    e2e_asr = E2E(idim_asr, odim_asr, train_args)
    logging.info(e2e_asr)
    asr_loss = Loss(e2e_asr, train_args.mtlalpha)

    # specify model architecture for TTS
    # reverse input and output dimension
    tts_idim, tts_odim = odim_asr, idim_asr - 3
    tts_loss = setup_tts_loss(tts_idim, tts_odim, train_args)
    logging.info(tts_loss)

    # define loss
    model = ASRTTSLoss(asr_loss, tts_loss, train_args)

    def cpu_loader(storage, location):
        return storage

    def remove_dataparallel(state_dict):
        # strip the "module." prefix that nn.DataParallel adds to keys
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if k.startswith("module."):
                k = k[7:]
            new_state_dict[k] = v
        return new_state_dict

    model.load_state_dict(
        remove_dataparallel(torch.load(args.model, map_location=cpu_loader)))

    # define model
    # Use the same reversed dimensions as setup_tts_loss above; the names
    # `idim`/`odim` used here before were never defined in this function.
    tacotron2 = Tacotron2(tts_idim, tts_odim, train_args)
    eos = str(tacotron2.idim - 1)

    # load trained model parameters
    # NOTE(review): args.model is loaded both into the joint ASR-TTS model
    # above and into this bare Tacotron2 -- confirm the checkpoint's keys
    # actually match Tacotron2's state dict.
    logging.info('reading model parameters from ' + args.model)
    torch_load(args.model, tacotron2)
    tacotron2.eval()

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    tacotron2 = tacotron2.to(device)

    # read json data
    with open(args.json, 'rb') as f:
        js = json.load(f)['utts']

    # check directory
    outdir = os.path.dirname(args.out)
    if len(outdir) != 0 and not os.path.exists(outdir):
        os.makedirs(outdir)

    # write to ark and scp file (see https://github.com/vesis84/kaldi-io-for-python)
    arkscp = 'ark:| copy-feats --print-args=false ark:- ark,scp:%s.ark,%s.scp' % (
        args.out, args.out)
    with torch.no_grad(), kaldi_io_py.open_or_fd(arkscp, 'wb') as f:
        for idx, utt_id in enumerate(js):
            x = js[utt_id]['output'][0]['tokenid'].split() + [eos]
            x = np.fromiter(map(int, x), dtype=np.int64)
            x = torch.LongTensor(x).to(device)

            # get speaker embedding
            if train_args.use_speaker_embedding:
                spemb = kaldi_io_py.read_vec_flt(
                    js[utt_id]['input'][1]['feat'])
                spemb = torch.FloatTensor(spemb).to(device)
            else:
                spemb = None

            # decode and write
            outs, _, _ = tacotron2.inference(x, args, spemb)
            if outs.size(0) == x.size(0) * args.maxlenratio:
                logging.warning("output length reaches maximum length (%s)." %
                                utt_id)
            logging.info(
                '(%d/%d) %s (size:%d->%d)' %
                (idx + 1, len(js), utt_id, x.size(0), outs.size(0)))
            kaldi_io_py.write_mat(f, outs.cpu().numpy(), utt_id)