Example #1
0
def infer(args):
    """Run diarization inference with a trained model (chainer backend).

    For each recording listed in ``args.data_dir`` (Kaldi-style data
    directory), computes frame-wise speaker-activity posteriors in
    fixed-size chunks and writes them to ``<out_dir>/<recid>.h5`` under
    the dataset name ``'T_hat'``.

    Args:
        args: parsed command-line arguments (model/feature/IO options).

    Raises:
        ValueError: if ``args.model_type`` is not 'Transformer'.
    """
    system_info.print_system_info()

    # Prepare model: input dimension is derived from the feature config.
    in_size = feature.get_input_dim(args.frame_size, args.context_size,
                                    args.input_transform)

    if args.model_type == 'Transformer':
        # FIX: pass the number of speakers as the first positional
        # argument, matching the constructor call used in train();
        # the original call omitted it, so in_size landed in the
        # n_speakers slot and the model could not match the checkpoint.
        # Dropout/alpha are 0 at inference time.
        model = TransformerDiarization(
            args.num_speakers,
            in_size,
            n_units=args.hidden_size,
            n_heads=args.transformer_encoder_n_heads,
            n_layers=args.transformer_encoder_n_layers,
            dropout=0,
            alpha=0)
    else:
        raise ValueError('Unknown model type.')

    # Load trained weights into the freshly built model.
    serializers.load_npz(args.model_file, model)

    if args.gpu >= 0:
        gpuid = use_single_gpu()
        model.to_gpu()

    kaldi_obj = kaldi_data.KaldiData(args.data_dir)
    for recid in kaldi_obj.wavs:
        # Feature pipeline: STFT -> transform -> context splicing
        # -> temporal subsampling (same as the training pipeline).
        data, rate = kaldi_obj.load_wav(recid)
        Y = feature.stft(data, args.frame_size, args.frame_shift)
        Y = feature.transform(Y, transform_type=args.input_transform)
        Y = feature.splice(Y, context_size=args.context_size)
        Y = Y[::args.subsampling]
        out_chunks = []
        # Inference mode: no gradients, train=False (disables dropout).
        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            hs = None  # recurrent/hidden state carried across chunks
            for start, end in _gen_chunk_indices(len(Y), args.chunk_size):
                Y_chunked = Variable(Y[start:end])
                if args.gpu >= 0:
                    Y_chunked.to_gpu(gpuid)
                hs, ys = model.estimate_sequential(hs, [Y_chunked])
                if args.gpu >= 0:
                    ys[0].to_cpu()
                out_chunks.append(ys[0].data)
                if args.save_attention_weight == 1:
                    att_fname = f"{recid}_{start}_{end}.att.npy"
                    att_path = os.path.join(args.out_dir, att_fname)
                    model.save_attention_weight(att_path)
        outfname = recid + '.h5'
        outpath = os.path.join(args.out_dir, outfname)
        if hasattr(model, 'label_delay'):
            # Undo the label delay used during training by shifting the
            # output back along the time axis.
            outdata = shift(np.vstack(out_chunks), (-model.label_delay, 0))
        else:
            outdata = np.vstack(out_chunks)
        with h5py.File(outpath, 'w') as wf:
            wf.create_dataset('T_hat', data=outdata)
Example #2
0
def train(args):
    """ Training model with chainer backend.
    This function is called from eend/bin/train.py with
    parsed command-line arguments.

    Builds train/dev datasets, constructs the model selected by
    args.model_type ('BLSTM' or 'Transformer'), sets up the optimizer
    and chainer Trainer with reporting/snapshot extensions, and runs
    training for args.max_epochs epochs.
    """
    # Fix all random seeds for reproducibility; CHAINER_SEED covers
    # chainer's own RNG, cudnn_deterministic forces deterministic kernels.
    np.random.seed(args.seed)
    os.environ['CHAINER_SEED'] = str(args.seed)
    chainer.global_config.cudnn_deterministic = True

    # Train/dev datasets share every feature-extraction parameter;
    # only the data directory differs.
    train_set = KaldiDiarizationDataset(
        args.train_data_dir,
        chunk_size=args.num_frames,
        context_size=args.context_size,
        input_transform=args.input_transform,
        frame_size=args.frame_size,
        frame_shift=args.frame_shift,
        subsampling=args.subsampling,
        rate=args.sampling_rate,
        use_last_samples=True,
        label_delay=args.label_delay,
        n_speakers=args.num_speakers,
        )
    dev_set = KaldiDiarizationDataset(
        args.valid_data_dir,
        chunk_size=args.num_frames,
        context_size=args.context_size,
        input_transform=args.input_transform,
        frame_size=args.frame_size,
        frame_shift=args.frame_shift,
        subsampling=args.subsampling,
        rate=args.sampling_rate,
        use_last_samples=True,
        label_delay=args.label_delay,
        n_speakers=args.num_speakers,
        )

    # Prepare model: peek at one example to get the input feature
    # dimension (Y.shape[1]); T (labels) is unused here.
    Y, T = train_set.get_example(0)

    if args.model_type == 'BLSTM':
        model = BLSTMDiarization(
                in_size=Y.shape[1],
                n_speakers=args.num_speakers,
                hidden_size=args.hidden_size,
                n_layers=args.num_lstm_layers,
                embedding_layers=args.embedding_layers,
                embedding_size=args.embedding_size,
                dc_loss_ratio=args.dc_loss_ratio,
                )
    elif args.model_type == 'Transformer':
        model = TransformerDiarization(
                args.num_speakers,
                Y.shape[1],
                n_units=args.hidden_size,
                n_heads=args.transformer_encoder_n_heads,
                n_layers=args.transformer_encoder_n_layers,
                dropout=args.transformer_encoder_dropout)
    else:
        raise ValueError('Possible model_type are "Transformer" and "BLSTM"')

    # gpuid doubles as the chainer device id below (-1 means CPU).
    if args.gpu >= 0:
        gpuid = use_single_gpu()
        print('GPU device {} is used'.format(gpuid))
        model.to_gpu()
    else:
        gpuid = -1
    print('Prepared model')

    # Setup optimizer. For 'noam' the learning rate starts at 0 and is
    # driven entirely by the NoamScheduler extension registered below.
    if args.optimizer == 'adam':
        optimizer = optimizers.Adam(alpha=args.lr)
    elif args.optimizer == 'sgd':
        optimizer = optimizers.SGD(lr=args.lr)
    elif args.optimizer == 'noam':
        optimizer = optimizers.Adam(alpha=0, beta1=0.9, beta2=0.98, eps=1e-9)
    else:
        raise ValueError(args.optimizer)

    optimizer.setup(model)
    if args.gradclip > 0:
        optimizer.add_hook(
            chainer.optimizer_hooks.GradientClipping(args.gradclip))

    # Init/Resume: initmodel loads weights only (fine-tuning / warm
    # start); args.resume further below restores full trainer state.
    if args.initmodel:
        print('Load model from', args.initmodel)
        serializers.load_npz(args.initmodel, model)

    # Multiprocess iterators: 4 workers, prefetch depth 2.
    # shared_mem=None lets chainer size the shared buffer automatically.
    train_iter = iterators.MultiprocessIterator(
            train_set,
            batch_size=args.batchsize,
            repeat=True, shuffle=True,
            # shared_mem=64000000,
            shared_mem=None,
            n_processes=4, n_prefetch=2)

    dev_iter = iterators.MultiprocessIterator(
            dev_set,
            batch_size=args.batchsize,
            repeat=False, shuffle=False,
            # shared_mem=64000000,
            shared_mem=None,
            n_processes=4, n_prefetch=2)

    # Optionally accumulate gradients over several batches before each
    # optimizer step (effective batch size = batchsize * steps).
    if args.gradient_accumulation_steps > 1:
        updater = GradientAccumulationUpdater(
            train_iter, optimizer, converter=_convert, device=gpuid)
    else:
        updater = training.StandardUpdater(
            train_iter, optimizer, converter=_convert, device=gpuid)

    trainer = training.Trainer(
            updater,
            (args.max_epochs, 'epoch'),
            out=os.path.join(args.model_save_dir))

    # Validation pass over dev_iter at the default (per-epoch) trigger.
    evaluator = extensions.Evaluator(
            dev_iter, model, converter=_convert, device=gpuid)
    trainer.extend(evaluator)

    # Noam schedule updates the Adam alpha every iteration (warmup then
    # inverse-sqrt decay, scaled by hidden size).
    if args.optimizer == 'noam':
        trainer.extend(
            NoamScheduler(args.hidden_size,
                          warmup_steps=args.noam_warmup_steps,
                          scale=args.noam_scale),
            trigger=(1, 'iteration'))

    # Resume full trainer state (optimizer, iterators, extensions).
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # MICRO AVERAGE
    # Each (numerator, denominator, name) triple yields a micro-averaged
    # ratio for both the training ('main/') and validation
    # ('validation/main/') observation streams.
    metrics = [
            ('diarization_error', 'speaker_scored', 'DER'),
            ('speech_miss', 'speech_scored', 'SAD_MR'),
            ('speech_falarm', 'speech_scored', 'SAD_FR'),
            ('speaker_miss', 'speaker_scored', 'MI'),
            ('speaker_falarm', 'speaker_scored', 'FA'),
            ('speaker_error', 'speaker_scored', 'CF'),
            ('correct', 'frames', 'accuracy')]
    for num, den, name in metrics:
        trainer.extend(extensions.MicroAverage(
            'main/{}'.format(num),
            'main/{}'.format(den),
            'main/{}'.format(name)))
        trainer.extend(extensions.MicroAverage(
            'validation/main/{}'.format(num),
            'validation/main/{}'.format(den),
            'validation/main/{}'.format(name)))

    # Fine-grained iteration-level log alongside the default per-epoch one.
    trainer.extend(extensions.LogReport(log_name='log_iter',
                   trigger=(1000, 'iteration')))

    trainer.extend(extensions.LogReport())
    # NOTE(review): 'main/diarization_error_rate' is not produced by the
    # MicroAverage triples above (those emit 'main/DER' etc.); presumably
    # the model reports it directly during forward — verify against the
    # model's reporter calls.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/diarization_error_rate',
         'validation/main/diarization_error_rate',
         'elapsed_time']))
    trainer.extend(extensions.PlotReport(
        ['main/loss', 'validation/main/loss'],
        x_key='epoch',
        file_name='loss.png'))
    trainer.extend(extensions.PlotReport(
        ['main/diarization_error_rate',
         'validation/main/diarization_error_rate'],
        x_key='epoch',
        file_name='DER.png'))
    trainer.extend(extensions.ProgressBar(update_interval=100))
    # Full trainer snapshot each epoch (enables --resume).
    trainer.extend(extensions.snapshot(
        filename='snapshot_epoch-{.updater.epoch}'))

    # Dump the computational graph of the loss once, for debugging.
    trainer.extend(extensions.dump_graph('main/loss', out_name="cg.dot"))

    trainer.run()
    print('Finished!')