Example 1
def main():
    parser = get_parser()
    args = parser.parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    with kaldiio.ReadHelper(args.rspecifier,
                            segments=args.segments) as reader, \
            file_writer_helper(args.wspecifier,
                               filetype=args.filetype,
                               write_num_frames=args.write_num_frames,
                               compress=args.compress,
                               compression_method=args.compression_method
                               ) as writer:
        for utt_id, (_, array) in reader:
            array = array.astype(numpy.float32)
            if args.normalize is not None and args.normalize != 1:
                array = array / (1 << (args.normalize - 1))
            spc = spectrogram(x=array,
                              n_fft=args.n_fft,
                              n_shift=args.n_shift,
                              win_length=args.win_length,
                              window=args.window)
            writer[utt_id] = spc
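The expression array / (1 << (args.normalize - 1)) above rescales N-bit integer PCM samples into roughly [-1.0, 1.0) floats before the spectrogram is computed. A minimal standalone check of that arithmetic (the helper name below is illustrative, not part of the original script):

import numpy

def pcm_to_float(array, bits=16):
    # Divide by 2**(bits - 1), e.g. 32768 for 16-bit PCM, mirroring
    # array / (1 << (args.normalize - 1)) in the example above.
    return array.astype(numpy.float32) / (1 << (bits - 1))

samples = numpy.array([-32768, 0, 16384], dtype=numpy.int16)
print(pcm_to_float(samples))  # -> [-1.   0.   0.5]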
Example 2
def load_and_convert_examples(args, tokenizer, transformer):
    examples = read_examples(args.input_file)

    logger.info("Training number: %s", str(len(examples)))
    with file_writer_helper(
            args.wspecifier,
            filetype=args.filetype,
            compress=args.compress,
            compression_method=args.compression_method) as writer:
        convert_examples_to_features(
            examples,
            writer,
            args.max_seq_length,
            tokenizer,
            transformer,
            # xlnet has a cls token at the end
            cls_token_at_end=bool(args.model_type in ['xlnet']),
            cls_token=tokenizer.cls_token,
            sep_token=tokenizer.sep_token,
            sep_token_extra=bool(args.model_type in ['roberta']),
            cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
            # pad on the left for xlnet
            pad_on_left=bool(args.model_type in ['xlnet']),
            pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
            device=args.device)
Example 3
def main():
    parser = get_parser()
    args = parser.parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    if args.preprocess_conf is not None:
        preprocessing = Transformation(args.preprocess_conf)
        logging.info('Apply preprocessing: {}'.format(preprocessing))
    else:
        preprocessing = None

    with file_writer_helper(args.wspecifier,
                            filetype=args.filetype,
                            write_num_frames=args.write_num_frames,
                            compress=args.compress,
                            compression_method=args.compression_method,
                            pcm_format=args.format) as writer:
        for utt_id, (rate,
                     array) in kaldiio.ReadHelper(args.rspecifier,
                                                  args.segments):
            if args.filetype == 'mat':
                # Kaldi matrices don't support integer dtypes
                array = array.astype(numpy.float32)

            if array.ndim == 1:
                # (Time) -> (Time, Channel)
                array = array[:, None]

            if args.normalize is not None and args.normalize != 1:
                array = array.astype(numpy.float32)
                array = array / (1 << (args.normalize - 1))

            if preprocessing is not None:
                orgtype = array.dtype
                out = preprocessing(array, uttid_list=utt_id)
                out = out.astype(orgtype)

                if args.keep_length:
                    # The length can be changed by stft, for example.
                    if len(out) > len(array):
                        out = out[:len(array)]
                    elif len(out) < len(array):
                        out = numpy.pad(out, [(0, len(array) - len(out))] +
                                        [(0, 0) for _ in range(out.ndim - 1)],
                                        mode='constant')

                array = out

            # shape = (Time, Channel)
            if args.filetype in ['sound.hdf5', 'sound']:
                # Write Tuple[int, numpy.ndarray] (scipy style)
                writer[utt_id] = (rate, array)
            else:
                writer[utt_id] = array
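The keep_length branch above keeps the preprocessed output at the input's frame count by truncating or zero-padding along the time axis. A small self-contained illustration of that invariant (the function name and shapes are made up for this sketch):

import numpy

def match_length(out, target_len):
    # Truncate or zero-pad out along the time axis to target_len frames.
    if len(out) > target_len:
        return out[:target_len]
    pad = [(0, target_len - len(out))] + [(0, 0)] * (out.ndim - 1)
    return numpy.pad(out, pad, mode='constant')

x = numpy.ones((7, 2))
print(match_length(x, 5).shape)   # (5, 2)
print(match_length(x, 10).shape)  # (10, 2)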
Example 4
def test_KaldiReader(tmpdir, filetype):
    ark = str(tmpdir.join('a.foo'))
    scp = str(tmpdir.join('a.scp'))
    fs = 16000

    with file_writer_helper(wspecifier=f'ark,scp:{ark},{scp}',
                            filetype=filetype,
                            write_num_frames='ark,t:out.txt',
                            compress=False,
                            compression_method=2,
                            pcm_format='wav') as writer:

        if 'sound' in filetype:
            aaa = np.random.randint(-10, 10, 100, dtype=np.int16)
            bbb = np.random.randint(-10, 10, 50, dtype=np.int16)
            writer['aaa'] = fs, aaa
            writer['bbb'] = fs, bbb
        else:
            aaa = np.random.randn(10, 10)
            bbb = np.random.randn(13, 5)
            writer['aaa'] = aaa
            writer['bbb'] = bbb
        valid = {'aaa': aaa, 'bbb': bbb}

    # 1. Test ark read
    if filetype != 'sound':
        for key, value in file_reader_helper(f'ark:{ark}',
                                             filetype=filetype,
                                             return_shape=False):
            if 'sound' in filetype:
                assert_scipy_wav_style(value)
                value = value[1]
            np.testing.assert_array_equal(value, valid[key])
    # 2. Test scp read
    for key, value in file_reader_helper(f'scp:{scp}',
                                         filetype=filetype,
                                         return_shape=False):
        if 'sound' in filetype:
            assert_scipy_wav_style(value)
            value = value[1]
        np.testing.assert_array_equal(value, valid[key])

    # 3. Test ark shape read
    if filetype != 'sound':
        for key, value in file_reader_helper(f'ark:{ark}',
                                             filetype=filetype,
                                             return_shape=True):
            if 'sound' in filetype:
                value = value[1]
            np.testing.assert_array_equal(value, valid[key].shape)
    # 4. Test scp shape read
    for key, value in file_reader_helper(f'scp:{scp}',
                                         filetype=filetype,
                                         return_shape=True):
        if 'sound' in filetype:
            value = value[1]
        np.testing.assert_array_equal(value, valid[key].shape)
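The filetype argument to this test is presumably supplied by a pytest parametrize decorator or fixture that is not shown in this excerpt; something along these lines (the exact value list is an assumption):

import pytest

@pytest.mark.parametrize('filetype', ['mat', 'hdf5', 'sound.hdf5', 'sound'])
def test_KaldiReader(tmpdir, filetype):
    ...  # body as above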
Example 5
def main():
    args = get_parser().parse_args()

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    if ":" in args.stats_rspecifier_or_rxfilename:
        is_rspecifier = True
        if args.stats_filetype == "npy":
            stats_filetype = "hdf5"
        else:
            stats_filetype = args.stats_filetype

        stats_dict = dict(
            file_reader_helper(args.stats_rspecifier_or_rxfilename,
                               stats_filetype))
    else:
        is_rspecifier = False
        if args.stats_filetype == "mat":
            stats = kaldiio.load_mat(args.stats_rspecifier_or_rxfilename)
        else:
            stats = numpy.load(args.stats_rspecifier_or_rxfilename)
        stats_dict = {None: stats}

    cmvn = CMVN(
        stats=stats_dict,
        norm_means=args.norm_means,
        norm_vars=args.norm_vars,
        utt2spk=args.utt2spk,
        spk2utt=args.spk2utt,
        reverse=args.reverse,
    )

    with file_writer_helper(
            args.wspecifier,
            filetype=args.out_filetype,
            write_num_frames=args.write_num_frames,
            compress=args.compress,
            compression_method=args.compression_method,
    ) as writer:
        for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype):
            if is_scipy_wav_style(mat):
                # Sound files are read as Tuple[int, ndarray]
                rate, mat = mat
            mat = cmvn(mat, utt if is_rspecifier else None)
            writer[utt] = mat
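A minimal sketch of what applying CMVN to one feature matrix amounts to, based on the stats layout built in Example 11 below (row 0 holds per-dimension sums with the frame count in its last column, row 1 holds squared sums); this mirrors, but is not, the actual CMVN class used above:

import numpy

def apply_cmvn(mat, stats, norm_means=True, norm_vars=False):
    count = stats[0, -1]
    mean = stats[0, :-1] / count
    if norm_means:
        mat = mat - mean
    if norm_vars:
        var = stats[1, :-1] / count - mean ** 2
        mat = mat / numpy.sqrt(numpy.maximum(var, 1e-20))
    return mat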
Example 6
def main():
    parser = get_parser()
    args = parser.parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    # Find the number of utterances
    with open(args.segments) as f:
        n_utt = sum(1 for _ in f)
    logging.info("%d utterances found to be processed." % n_utt)

    # Compute fbank features
    with kaldiio.ReadHelper(
            args.rspecifier,
            segments=args.segments) as reader, file_writer_helper(
                args.wspecifier,
                filetype=args.filetype,
                write_num_frames=args.write_num_frames,
                compress=args.compress,
                compression_method=args.compression_method,
            ) as writer:
        for i, struct in enumerate(reader, start=1):
            logging.info("processing %d/%d(%.2f%%)" %
                         (i, n_utt, 100 * i / n_utt))
            utt_id, (rate, array) = struct
            try:
                assert rate == args.fs
                array = array.astype(numpy.float32)
                if args.normalize is not None and args.normalize != 1:
                    array = array / (1 << (args.normalize - 1))

                lmspc = logmelspectrogram(
                    x=array,
                    fs=args.fs,
                    n_mels=args.n_mels,
                    n_fft=args.n_fft,
                    n_shift=args.n_shift,
                    win_length=args.win_length,
                    window=args.window,
                    fmin=args.fmin,
                    fmax=args.fmax,
                )
                writer[utt_id] = lmspc
            except Exception:
                logging.warning("failed to compute fbank for utt_id=`%s`" %
                                utt_id)
Example 7
def main():
  parser = get_parser()
  args = parser.parse_args()

  d = kaldiio.load_ark(args.rspecifier)

  with file_writer_helper(
      args.wspecifier,
      filetype='mat',
      write_num_frames=args.write_num_frames,
      compress=args.compress,
      compression_method=args.compression_method) as writer:
    for utt, mat in d:
      writer[utt] = mat
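The rspecifier/wspecifier strings used throughout these examples ('ark:...', 'scp:...', 'ark,scp:...') follow Kaldi's table I/O convention. A minimal round trip with kaldiio alone (file names are placeholders):

import numpy
import kaldiio

feats = {'utt1': numpy.random.randn(5, 3).astype(numpy.float32)}
# Write a binary ark together with an scp index, as the 'ark,scp:...' specifiers above do.
kaldiio.save_ark('feats.ark', feats, scp='feats.scp')
for utt, mat in kaldiio.load_ark('feats.ark'):
    print(utt, mat.shape)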
Example 8
def main():
    parser = get_parser()
    args = parser.parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    with VideoReader(args.rspecifier) as reader, file_writer_helper(
            args.wspecifier,
            filetype=args.filetype,
            write_num_frames=args.write_num_frames,
            compress=args.compress,
            compression_method=args.compression_method,
    ) as writer:
        for utt_id, v_feature in reader:
            writer[utt_id] = v_feature
Example 9
def main():
    parser = get_parser()
    args = parser.parse_args()

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    if args.preprocess_conf is not None:
        preprocessing = Transformation(args.preprocess_conf)
        logging.info("Apply preprocessing: {}".format(preprocessing))
    else:
        preprocessing = None

    with file_writer_helper(
            args.wspecifier,
            filetype=args.out_filetype,
            write_num_frames=args.write_num_frames,
            compress=args.compress,
            compression_method=args.compression_method,
    ) as writer:
        for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype):
            if is_scipy_wav_style(mat):
                # Sound files are read as Tuple[int, ndarray]
                rate, mat = mat

            if preprocessing is not None:
                mat = preprocessing(mat, uttid_list=utt)

            # shape = (Time, Channel)
            if args.out_filetype in ["sound.hdf5", "sound"]:
                # Write Tuple[int, numpy.ndarray] (scipy style)
                writer[utt] = (rate, mat)
            else:
                writer[utt] = mat
Example 10
def enhance(args):
    """Dumping enhanced speech and mask.

    Args:
        args (namespace): The program arguments.
    """
    set_deterministic_pytorch(args)
    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # load trained model parameters
    logging.info('reading model parameters from ' + args.model)
    model_class = dynamic_import(train_args.model_module)
    model = model_class(idim, odim, train_args)
    assert isinstance(model, ASRInterface)
    torch_load(args.model, model)
    model.recog_args = args

    # gpu
    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info('gpu id: ' + str(gpu_id))
        model.cuda()

    # read json data
    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']

    load_inputs_and_targets = LoadInputsAndTargets(
        mode='asr',
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=None  # Apply pre_process in outer func
    )
    if args.batchsize == 0:
        args.batchsize = 1

    # Creates writers for outputs from the network
    if args.enh_wspecifier is not None:
        enh_writer = file_writer_helper(args.enh_wspecifier,
                                        filetype=args.enh_filetype)
    else:
        enh_writer = None

    # Creates a Transformation instance
    preprocess_conf = (train_args.preprocess_conf if
                       args.preprocess_conf is None else args.preprocess_conf)
    if preprocess_conf is not None:
        logging.info('Use preprocessing: {}'.format(preprocess_conf))
        transform = Transformation(preprocess_conf)
    else:
        transform = None

    # Creates an IStft instance
    istft = None
    frame_shift = args.istft_n_shift  # Used for plotting the spectrogram
    if args.apply_istft:
        if preprocess_conf is not None:
            # Read the config file and find the stft setting
            with open(preprocess_conf) as f:
                # Json format: e.g.
                #    {"process": [{"type": "stft",
                #                  "win_length": 400,
                #                  "n_fft": 512, "n_shift": 160,
                #                  "window": "han"},
                #                 {"type": "foo", ...}, ...]}
                conf = json.load(f)
                assert 'process' in conf, conf
                # Find stft setting
                for p in conf['process']:
                    if p['type'] == 'stft':
                        istft = IStft(win_length=p['win_length'],
                                      n_shift=p['n_shift'],
                                      window=p.get('window', 'hann'))
                        logging.info('stft is found in {}. '
                                     'Setting istft config from it\n{}'.format(
                                         preprocess_conf, istft))
                        frame_shift = p['n_shift']
                        break
        if istft is None:
            # Set from command line arguments
            istft = IStft(win_length=args.istft_win_length,
                          n_shift=args.istft_n_shift,
                          window=args.istft_window)
            logging.info(
                'Setting istft config from the command line args\n{}'.format(
                    istft))

    # sort data
    keys = list(js.keys())
    feat_lens = [js[key]['input'][0]['shape'][0] for key in keys]
    sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
    keys = [keys[i] for i in sorted_index]

    def grouper(n, iterable, fillvalue=None):
        kargs = [iter(iterable)] * n
        return zip_longest(*kargs, fillvalue=fillvalue)

    num_images = 0
    if not os.path.exists(args.image_dir):
        os.makedirs(args.image_dir)

    for names in grouper(args.batchsize, keys, None):
        batch = [(name, js[name]) for name in names]

        # May be in time region: (Batch, [Time, Channel])
        org_feats = load_inputs_and_targets(batch)[0]
        if transform is not None:
            # May be in time-freq region: (Batch, [Time, Channel, Freq])
            feats = transform(org_feats, train=False)
        else:
            feats = org_feats

        with torch.no_grad():
            enhanced, mask, ilens = model.enhance(feats)

        for idx, name in enumerate(names):
            # Assuming mask, feats : [Batch, Time, Channel, Freq]
            #          enhanced    : [Batch, Time, Freq]
            enh = enhanced[idx][:ilens[idx]]
            mas = mask[idx][:ilens[idx]]
            feat = feats[idx]

            # Plot spectrogram
            if args.image_dir is not None and num_images < args.num_images:
                import matplotlib.pyplot as plt
                num_images += 1
                ref_ch = 0

                plt.figure(figsize=(20, 10))
                plt.subplot(4, 1, 1)
                plt.title('Mask [ref={}ch]'.format(ref_ch))
                plot_spectrogram(plt,
                                 mas[:, ref_ch].T,
                                 fs=args.fs,
                                 mode='linear',
                                 frame_shift=frame_shift,
                                 bottom=False,
                                 labelbottom=False)

                plt.subplot(4, 1, 2)
                plt.title('Noisy speech [ref={}ch]'.format(ref_ch))
                plot_spectrogram(plt,
                                 feat[:, ref_ch].T,
                                 fs=args.fs,
                                 mode='db',
                                 frame_shift=frame_shift,
                                 bottom=False,
                                 labelbottom=False)

                plt.subplot(4, 1, 3)
                plt.title('Masked speech [ref={}ch]'.format(ref_ch))
                plot_spectrogram(plt, (feat[:, ref_ch] * mas[:, ref_ch]).T,
                                 frame_shift=frame_shift,
                                 fs=args.fs,
                                 mode='db',
                                 bottom=False,
                                 labelbottom=False)

                plt.subplot(4, 1, 4)
                plt.title('Enhanced speech')
                plot_spectrogram(plt,
                                 enh.T,
                                 fs=args.fs,
                                 mode='db',
                                 frame_shift=frame_shift)

                plt.savefig(os.path.join(args.image_dir, name + '.png'))
                plt.clf()

            # Write enhanced wave files
            if enh_writer is not None:
                if istft is not None:
                    enh = istft(enh)

                if args.keep_length:
                    if len(org_feats[idx]) < len(enh):
                        # Truncate the frames added by stft padding
                        enh = enh[:len(org_feats[idx])]
                    elif len(org_feats[idx]) > len(enh):
                        padwidth = [(0, (len(org_feats[idx]) - len(enh)))] \
                            + [(0, 0)] * (enh.ndim - 1)
                        enh = np.pad(enh, padwidth, mode='constant')

                if args.enh_filetype in ('sound', 'sound.hdf5'):
                    enh_writer[name] = (args.fs, enh)
                else:
                    # Hint: To dump stft_signal, mask, etc.,
                    # enh_filetype='hdf5' might be convenient.
                    enh_writer[name] = enh

            if num_images >= args.num_images and enh_writer is None:
                logging.info('Breaking the process.')
                break
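The inner grouper helper above batches utterance keys with itertools.zip_longest, padding the final, possibly incomplete group with None. A standalone check:

from itertools import zip_longest

def grouper(n, iterable, fillvalue=None):
    kargs = [iter(iterable)] * n
    return zip_longest(*kargs, fillvalue=fillvalue)

print(list(grouper(3, ['a', 'b', 'c', 'd'])))  # [('a', 'b', 'c'), ('d', None, None)]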
Example 11
def main():
    args = get_parser().parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    is_wspecifier = ":" in args.wspecifier_or_wxfilename

    if is_wspecifier:
        if args.spk2utt is not None:
            logging.info("Performing as speaker CMVN mode")
            utt2spk_dict = {}
            with open(args.spk2utt) as f:
                for line in f:
                    spk, utts = line.rstrip().split(None, 1)
                    for utt in utts.split():
                        utt2spk_dict[utt] = spk

            def utt2spk(x):
                return utt2spk_dict[x]

        else:
            logging.info("Performing as utterance CMVN mode")

            def utt2spk(x):
                return x

        if args.out_filetype == "npy":
            logging.warning("--out-filetype npy is allowed only for "
                            "Global CMVN mode, changing to hdf5")
            args.out_filetype = "hdf5"

    else:
        logging.info("Performing as global CMVN mode")
        if args.spk2utt is not None:
            logging.warning("spk2utt is not used for global CMVN mode")

        def utt2spk(x):
            return None

        if args.out_filetype == "hdf5":
            logging.warning("--out-filetype hdf5 is not allowed for "
                            "Global CMVN mode, changing to npy")
            args.out_filetype = "npy"

    if args.preprocess_conf is not None:
        preprocessing = Transformation(args.preprocess_conf)
        logging.info("Apply preprocessing: {}".format(preprocessing))
    else:
        preprocessing = None

    # Calculate stats for each speaker
    counts = {}
    sum_feats = {}
    square_sum_feats = {}

    idx = 0
    for idx, (utt, matrix) in enumerate(
            file_reader_helper(args.rspecifier, args.in_filetype), 1):
        if is_scipy_wav_style(matrix):
            # Sound files are read as Tuple[int, ndarray]
            rate, matrix = matrix
        if preprocessing is not None:
            matrix = preprocessing(matrix, uttid_list=utt)

        spk = utt2spk(utt)

        # Initialize at the first occurrence of the speaker
        if spk not in counts:
            counts[spk] = 0
            feat_shape = matrix.shape[1:]
            # Accumulate in double precision
            sum_feats[spk] = np.zeros(feat_shape, dtype=np.float64)
            square_sum_feats[spk] = np.zeros(feat_shape, dtype=np.float64)

        counts[spk] += matrix.shape[0]
        sum_feats[spk] += matrix.sum(axis=0)
        square_sum_feats[spk] += (matrix**2).sum(axis=0)
    logging.info("Processed {} utterances".format(idx))
    assert idx > 0, idx

    cmvn_stats = {}
    for spk in counts:
        feat_shape = sum_feats[spk].shape
        cmvn_shape = (2, feat_shape[0] + 1) + feat_shape[1:]
        _cmvn_stats = np.empty(cmvn_shape, dtype=np.float64)
        _cmvn_stats[0, :-1] = sum_feats[spk]
        _cmvn_stats[1, :-1] = square_sum_feats[spk]

        _cmvn_stats[0, -1] = counts[spk]
        _cmvn_stats[1, -1] = 0.0

        # The mean and std can be obtained as follows:
        # >>> N = _cmvn_stats[0, -1]
        # >>> mean = _cmvn_stats[0, :-1] / N
        # >>> std = np.sqrt(_cmvn_stats[1, :-1] / N - mean ** 2)

        cmvn_stats[spk] = _cmvn_stats

    # Per utterance or speaker CMVN
    if is_wspecifier:
        with file_writer_helper(args.wspecifier_or_wxfilename,
                                filetype=args.out_filetype) as writer:
            for spk, mat in cmvn_stats.items():
                writer[spk] = mat

    # Global CMVN
    else:
        matrix = cmvn_stats[None]
        if args.out_filetype == "npy":
            np.save(args.wspecifier_or_wxfilename, matrix)
        elif args.out_filetype == "mat":
            # Kaldi supports only matrix or vector
            kaldiio.save_mat(args.wspecifier_or_wxfilename, matrix)
        else:
            raise RuntimeError("Not supporting: --out-filetype {}".format(
                args.out_filetype))
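The comment block inside the loop above shows how mean and std are recovered from the accumulated stats; a runnable check of that layout on toy data (every name here is local to the sketch):

import numpy as np

feats = np.array([[1.0, 2.0], [3.0, 6.0]])  # (Time, Dim)
stats = np.empty((2, feats.shape[1] + 1))
stats[0, :-1] = feats.sum(axis=0)
stats[1, :-1] = (feats ** 2).sum(axis=0)
stats[0, -1] = feats.shape[0]
stats[1, -1] = 0.0

N = stats[0, -1]
mean = stats[0, :-1] / N
std = np.sqrt(stats[1, :-1] / N - mean ** 2)
print(mean)  # [2. 4.]
print(std)   # [1. 2.]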
Example 12
def test_KaldiReader(tmpdir, filetype):
    ark = str(tmpdir.join("a.foo"))
    scp = str(tmpdir.join("a.scp"))
    fs = 16000

    with file_writer_helper(
        wspecifier=f"ark,scp:{ark},{scp}",
        filetype=filetype,
        write_num_frames="ark,t:out.txt",
        compress=False,
        compression_method=2,
        pcm_format="wav",
    ) as writer:

        if "sound" in filetype:
            aaa = np.random.randint(-10, 10, 100, dtype=np.int16)
            bbb = np.random.randint(-10, 10, 50, dtype=np.int16)
        else:
            aaa = np.random.randn(10, 10)
            bbb = np.random.randn(13, 5)
        if "sound" in filetype:
            writer["aaa"] = fs, aaa
            writer["bbb"] = fs, bbb
        else:
            writer["aaa"] = aaa
            writer["bbb"] = bbb
        valid = {"aaa": aaa, "bbb": bbb}

    # 1. Test ark read
    if filetype != "sound":
        for key, value in file_reader_helper(
            f"ark:{ark}", filetype=filetype, return_shape=False
        ):
            if "sound" in filetype:
                assert_scipy_wav_style(value)
                value = value[1]
            np.testing.assert_array_equal(value, valid[key])
    # 2. Test scp read
    for key, value in file_reader_helper(
        f"scp:{scp}", filetype=filetype, return_shape=False
    ):
        if "sound" in filetype:
            assert_scipy_wav_style(value)
            value = value[1]
        np.testing.assert_array_equal(value, valid[key])

    # 3. Test ark shape read
    if filetype != "sound":
        for key, value in file_reader_helper(
            f"ark:{ark}", filetype=filetype, return_shape=True
        ):
            if "sound" in filetype:
                value = value[1]
            np.testing.assert_array_equal(value, valid[key].shape)
    # 4. Test scp shape read
    for key, value in file_reader_helper(
        f"scp:{scp}", filetype=filetype, return_shape=True
    ):
        if "sound" in filetype:
            value = value[1]
        np.testing.assert_array_equal(value, valid[key].shape)
Example 13
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected in the list: " +
                        ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    parser.add_argument('--filetype',
                        type=str,
                        default='mat',
                        choices=['mat', 'hdf5'],
                        help='Specify the file format for output. '
                        '"mat" is the matrix format in kaldi')
    parser.add_argument('--compress',
                        type=bool,
                        default=False,
                        help='Save in compressed format')
    parser.add_argument(
        '--compression-method',
        type=int,
        default=2,
        help='Specify the method (if mat) or gzip level (if hdf5)')

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument('input_file', type=str, help='Input file')
    parser.add_argument('wspecifier', type=str, help='Write specifier')
    args = parser.parse_args()

    # Setup CUDA, GPU & distributed training
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()
    args.device = device

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)

    # Set seed
    set_seed(args)

    args.model_type = args.model_type.lower()

    if args.model_type == 'sbert':
        transformer = SentenceTransformer(args.model_name_or_path)
        examples = read_examples(args.input_file)
        embeddings = transformer.encode([e.text for e in examples])

        with file_writer_helper(
                args.wspecifier,
                filetype=args.filetype,
                compress=args.compress,
                compression_method=args.compression_method) as writer:
            for i in range(len(examples)):
                writer[examples[i].unique_id] = embeddings[i]
    else:
        config_class, model_class, tokenizer_class = MODEL_CLASSES[
            args.model_type]
        config = config_class.from_pretrained(args.model_name_or_path)
        tokenizer = tokenizer_class.from_pretrained(
            args.model_name_or_path, do_lower_case=args.do_lower_case)
        transformer = model_class.from_pretrained(
            args.model_name_or_path,
            from_tf=bool('.ckpt' in args.model_name_or_path),
            config=config)

        transformer.eval()
        transformer.to(args.device)

        with torch.no_grad():
            load_and_convert_examples(args, tokenizer, transformer)

    logger.info('Done converting {} to {}'.format(args.input_file,
                                                  args.wspecifier))