Example 1
def test_tacotron2_multi_gpu_trainable(model_dict):
    ngpu = 2
    device_ids = list(range(ngpu))
    bs = 10
    maxin_len = 10
    maxout_len = 10
    idim = 5
    odim = 10
    model_args = make_model_args(**model_dict)
    loss_args = make_loss_args()
    batch = prepare_inputs(bs, idim, odim, maxin_len, maxout_len,
                           model_args['spk_embed_dim'], model_args['spc_dim'])
    batch = (x.cuda() if x is not None else None for x in batch)

    # define model
    tacotron2 = Tacotron2(idim, odim, Namespace(**model_args))
    tacotron2 = torch.nn.DataParallel(tacotron2, device_ids)
    model = Tacotron2Loss(tacotron2, **loss_args)
    optimizer = torch.optim.Adam(model.parameters())
    model.cuda()

    # trainable
    loss = model(*batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
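
The pattern above (wrap the network in torch.nn.DataParallel, wrap that in a loss module, then run one optimizer step) can be reproduced in isolation. The following is a minimal, self-contained sketch using a toy linear model in place of Tacotron2; it assumes CUDA with at least two visible devices and does not use any of the ESPnet test helpers.

import torch


class ToyLoss(torch.nn.Module):
    """Minimal loss wrapper mirroring how Tacotron2Loss is used above."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.criterion = torch.nn.MSELoss()

    def forward(self, xs, ys):
        return self.criterion(self.model(xs), ys)


# toy network wrapped in DataParallel, then in the loss module
net = torch.nn.DataParallel(torch.nn.Linear(5, 10), device_ids=[0, 1])
model = ToyLoss(net).cuda()
optimizer = torch.optim.Adam(model.parameters())

# one training step on random data
xs, ys = torch.randn(8, 5).cuda(), torch.randn(8, 10).cuda()
loss = model(xs, ys)
optimizer.zero_grad()
loss.backward()
optimizer.step()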
Example 2
def decode(args):
    '''RUN DECODING'''
    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # show arguments
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # define model
    tacotron2 = Tacotron2(idim, odim, train_args)
    eos = str(tacotron2.idim - 1)

    # load trained model parameters
    logging.info('reading model parameters from ' + args.model)
    torch_load(args.model, tacotron2)
    tacotron2.eval()

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    tacotron2 = tacotron2.to(device)

    # read json data
    with open(args.json, 'rb') as f:
        js = json.load(f)['utts']

    # check and create the output directory
    outdir = os.path.dirname(args.out)
    if len(outdir) != 0 and not os.path.exists(outdir):
        os.makedirs(outdir)

    # write to ark and scp file (see https://github.com/vesis84/kaldi-io-for-python)
    arkscp = 'ark:| copy-feats --print-args=false ark:- ark,scp:%s.ark,%s.scp' % (
        args.out, args.out)
    with torch.no_grad(), kaldi_io_py.open_or_fd(arkscp, 'wb') as f:
        for idx, utt_id in enumerate(js.keys()):
            x = js[utt_id]['output'][0]['tokenid'].split() + [eos]
            x = np.fromiter(map(int, x), dtype=np.int64)
            x = torch.LongTensor(x).to(device)

            # get speaker embedding
            if train_args.use_speaker_embedding:
                spemb = kaldi_io_py.read_vec_flt(
                    js[utt_id]['input'][1]['feat'])
                spemb = torch.FloatTensor(spemb).to(device)
            else:
                spemb = None

            # decode and write
            outs, _, _ = tacotron2.inference(x, args, spemb)
            if outs.size(0) == x.size(0) * args.maxlenratio:
                logging.warning("output length reaches maximum length (%s)." %
                                utt_id)
            logging.info(
                '(%d/%d) %s (size:%d->%d)' %
                (idx + 1, len(js.keys()), utt_id, x.size(0), outs.size(0)))
            kaldi_io_py.write_mat(f, outs.cpu().numpy(), utt_id)
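
A hedged sketch of how decode() might be invoked. The attribute names model, model_conf, json, out, ngpu and maxlenratio are taken from the function body above; every path and value is a placeholder, and the extra attributes threshold and minlenratio are assumptions about what tacotron2.inference expects rather than something shown in this snippet.

from argparse import Namespace

# hypothetical invocation; all paths are placeholders
args = Namespace(
    model='exp/train_tacotron2/results/model.loss.best',
    model_conf='exp/train_tacotron2/results/model.json',
    json='dump/eval/data.json',
    out='exp/decode_eval/feats',
    ngpu=0,               # decode on CPU
    maxlenratio=10.0,     # used by the length check above
    minlenratio=0.0,      # assumed to be read inside inference()
    threshold=0.5,        # assumed stop-token threshold for inference()
)
decode(args)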
Example 3
def test_tacotron2_trainable_and_decodable(model_dict, loss_dict):
    # make args
    model_args = make_model_args(**model_dict)
    loss_args = make_loss_args(**loss_dict)
    inference_args = make_inference_args()

    # setup batch
    bs = 2
    maxin_len = 10
    maxout_len = 10
    idim = 5
    odim = 10
    if model_args['use_cbhg']:
        model_args['spc_dim'] = 129
    if model_args['use_speaker_embedding']:
        model_args['spk_embed_dim'] = 128
    batch = prepare_inputs(bs, idim, odim, maxin_len, maxout_len,
                           model_args['spk_embed_dim'], model_args['spc_dim'])
    xs, ilens, ys, labels, olens, spembs, spcs = batch

    # define model
    model = Tacotron2(idim, odim, Namespace(**model_args))
    criterion = Tacotron2Loss(model, **loss_args)
    optimizer = torch.optim.Adam(model.parameters())

    # trainable
    loss = criterion(xs, ilens, ys, labels, olens, spembs, spcs)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # decodable
    model.eval()
    with torch.no_grad():
        spemb = None if model_args['spk_embed_dim'] is None else spembs[0]
        model.inference(xs[0][:ilens[0]], Namespace(**inference_args), spemb)
        att_ws = model.calculate_all_attentions(xs, ilens, ys, spembs)
    assert att_ws.shape[0] == bs
    assert att_ws.shape[1] == max(olens)
    assert att_ws.shape[2] == max(ilens)
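
A hedged sketch of what prepare_inputs could return, inferred only from how the batch is unpacked and consumed in the tests above (xs, ilens, ys, labels, olens, spembs, spcs); the real helper in the test suite may build these tensors differently.

import numpy as np
import torch


def prepare_inputs_sketch(bs, idim, odim, maxin_len, maxout_len,
                          spk_embed_dim=None, spc_dim=None):
    # random, descending lengths so padded batches stay sorted
    ilens = sorted(np.random.randint(2, maxin_len, bs).tolist(), reverse=True)
    olens = sorted(np.random.randint(2, maxout_len, bs).tolist(), reverse=True)
    xs = torch.randint(0, idim, (bs, max(ilens)))   # padded token-id sequences
    ys = torch.randn(bs, max(olens), odim)          # padded target feature sequences
    labels = torch.zeros(bs, max(olens))            # stop-token targets
    for i, olen in enumerate(olens):
        labels[i, olen - 1:] = 1.0                  # mark the final frame and padding
    spembs = torch.randn(bs, spk_embed_dim) if spk_embed_dim is not None else None
    spcs = torch.randn(bs, max(olens), spc_dim) if spc_dim is not None else None
    return xs, torch.LongTensor(ilens), ys, labels, torch.LongTensor(olens), spembs, spcs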
Example 4
def train(args):
    '''RUN TRAINING'''
    # seed setting
    torch.manual_seed(args.seed)

    # use deterministic computation or not
    if args.debugmode < 1:
        torch.backends.cudnn.deterministic = False
        logging.info('torch cudnn deterministic is disabled')
    else:
        torch.backends.cudnn.deterministic = True

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())

    # NOTE: input and output are reversed in the TTS json
    # (text is stored as 'output', speech features as 'input')
    idim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    if args.use_cbhg:
        args.spc_dim = int(valid_json[utts[0]]['input'][1]['shape'][1])
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]['input'][1]['shape'][0])
    else:
        args.spk_embed_dim = None
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)), indent=4,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # specify model architecture
    tacotron2 = Tacotron2(idim, odim, args)
    logging.info(tacotron2)

    # check the use of multi-gpu
    if args.ngpu > 1:
        tacotron2 = torch.nn.DataParallel(tacotron2,
                                          device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    tacotron2 = tacotron2.to(device)

    # define loss
    model = Tacotron2Loss(tacotron2, args.use_masking, args.bce_pos_weight)
    reporter = model.reporter

    # Setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 eps=args.eps,
                                 weight_decay=args.weight_decay)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, 'target', reporter)
    setattr(optimizer, 'serialize', lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(True, args.use_speaker_embedding,
                                args.use_cbhg)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1)
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1)
    # hack to make the batch size argument effectively 1;
    # the actual batch size is carried inside each list element
    if args.n_iter_processes > 0:
        train_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(train_batchset, converter.transform),
            batch_size=1,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid_batchset, converter.transform),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = chainer.iterators.SerialIterator(TransformDataset(
            train_batchset, converter.transform),
                                                      batch_size=1)
        valid_iter = chainer.iterators.SerialIterator(TransformDataset(
            valid_batchset, converter.transform),
                                                      batch_size=1,
                                                      repeat=False,
                                                      shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            converter, device)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(tacotron2,
                                   'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))

    # Save attention figure for each epoch
    if args.num_save_attention > 0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(tacotron2, "module"):
            att_vis_fn = tacotron2.module.calculate_all_attentions
        else:
            att_vis_fn = tacotron2.calculate_all_attentions
        trainer.extend(PlotAttentionReport(att_vis_fn,
                                           data,
                                           args.outdir + '/att_ws',
                                           converter=CustomConverter(
                                               False,
                                               args.use_speaker_embedding),
                                           device=device,
                                           reverse=True),
                       trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    plot_keys = [
        'main/loss', 'validation/main/loss', 'main/l1_loss',
        'validation/main/l1_loss', 'main/mse_loss', 'validation/main/mse_loss',
        'main/bce_loss', 'validation/main/bce_loss'
    ]
    trainer.extend(
        extensions.PlotReport(['main/l1_loss', 'validation/main/l1_loss'],
                              'epoch',
                              file_name='l1_loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/mse_loss', 'validation/main/mse_loss'],
                              'epoch',
                              file_name='mse_loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/bce_loss', 'validation/main/bce_loss'],
                              'epoch',
                              file_name='bce_loss.png'))
    if args.use_cbhg:
        plot_keys += [
            'main/cbhg_l1_loss', 'validation/main/cbhg_l1_loss',
            'main/cbhg_mse_loss', 'validation/main/cbhg_mse_loss'
        ]
        trainer.extend(
            extensions.PlotReport(
                ['main/cbhg_l1_loss', 'validation/main/cbhg_l1_loss'],
                'epoch',
                file_name='cbhg_l1_loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/cbhg_mse_loss', 'validation/main/cbhg_mse_loss'],
                'epoch',
                file_name='cbhg_mse_loss.png'))
    trainer.extend(
        extensions.PlotReport(plot_keys, 'epoch', file_name='loss.png'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL,
                                                 'iteration')))
    report_keys = plot_keys[:]
    report_keys[0:0] = ['epoch', 'iteration', 'elapsed_time']
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))
    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))

    # Run the training
    trainer.run()
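
Both train() and decode() read a Kaldi-style data.json. Below is a hedged sketch of its layout, reconstructed only from the fields accessed above ('input', 'output', 'shape', 'feat', 'tokenid'); the utterance id and every concrete value are placeholders, and the second 'input' entry is present only when --use-speaker-embedding or --use-cbhg is enabled.

data_json_sketch = {
    "utts": {
        "utt_0001": {
            "input": [
                # acoustic target features: odim = shape[1]
                {"feat": "dump/train/feats.1.ark:42", "shape": [250, 80]},
                # speaker embedding (or spectrogram when CBHG is used):
                # spk_embed_dim = shape[0], spc_dim = shape[1]
                {"feat": "dump/train/xvector.1.ark:7", "shape": [512]},
            ],
            "output": [
                # text side: idim = shape[1]; token ids as a space-separated string
                {"shape": [23, 45], "tokenid": "14 25 7 3 41"},
            ],
        }
    }
}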