コード例 #1
0
ファイル: IO.py プロジェクト: lbaermann/opennmt-py
def build_dataset(fields,
                  data_type,
                  src_path,
                  tgt_path,
                  src_dir=None,
                  second_data_type=None,
                  second_src_path=None,
                  src_seq_length=0,
                  tgt_seq_length=0,
                  src_seq_length_trunc=0,
                  tgt_seq_length_trunc=0,
                  dynamic_dict=True,
                  sample_rate=0,
                  window_size=0,
                  window_stride=0,
                  window=None,
                  normalize_audio=True,
                  use_filter_pred=True,
                  file_to_tensor_fn=None):
    """Build a dataset object from source/target corpus files.

    The source side is read with ``_make_examples_nfeats_tpl``; the target
    side is always read as text. If ``second_data_type`` is given, a second
    (non-text) source modality is read from ``second_src_path`` and the three
    example streams are combined into a ``MultiModalDataset``. Otherwise a
    ``TextDataset``, ``ImageDataset`` or ``AudioDataset`` is built depending
    on ``data_type``.

    Args:
        fields: field mapping, forwarded unchanged to the dataset constructor.
        data_type: primary source type; one of 'text', 'img', 'audio'.
            Must be 'text' when a second modality is used.
        src_path: path to the primary source corpus.
        tgt_path: path to the target (text) corpus.
        src_dir: source directory forwarded to the example readers;
            required when a second modality is used.
        second_data_type: optional type of a second source modality;
            must not be 'text'.
        second_src_path: corpus path for the second modality; required
            together with ``second_data_type``.
        src_seq_length: forwarded to the text and multimodal datasets.
        tgt_seq_length: forwarded to every dataset constructor.
        src_seq_length_trunc: truncation length used when reading the
            source example(s).
        tgt_seq_length_trunc: truncation length used when reading the
            target examples.
        dynamic_dict: forwarded to ``TextDataset`` only.
        sample_rate, window_size, window_stride, window, normalize_audio:
            audio options forwarded to the source example readers and to
            ``AudioDataset``.
        use_filter_pred: forwarded to every dataset constructor.
        file_to_tensor_fn: forwarded to ``_make_examples_nfeats_tpl``.

    Returns:
        The constructed dataset object.

    Raises:
        AssertionError: if the multimodal arguments are inconsistent.
        ValueError: if ``data_type`` is not 'text', 'img' or 'audio'.
    """
    use_second_modality = second_data_type is not None
    if use_second_modality:
        # Only implemented for primary input type text.
        assert data_type == 'text', \
            'A second modality is only supported when data_type is text.'
        # Second data type should not be text. One could simply append the
        # secondary text to the primary input.
        assert second_data_type != 'text', 'second_data_type cannot be text.'
        assert second_src_path is not None and src_dir is not None, \
            'If second_data_type is set, second_src_path as well as src_dir needs to be present'

    # Validate before any corpus file is read. An unsupported data_type would
    # otherwise fall through every branch below and leave `dataset` unbound.
    if data_type not in ('text', 'img', 'audio'):
        raise ValueError('Unsupported data_type: %r' % (data_type,))

    # Build src/tgt examples iterator from corpus files, also extract
    # number of features.
    src_examples_iter, num_src_feats = \
        _make_examples_nfeats_tpl(data_type, src_path, src_dir,
                                  src_seq_length_trunc, sample_rate,
                                  window_size, window_stride,
                                  window, normalize_audio,
                                  file_to_tensor_fn=file_to_tensor_fn)
    if use_second_modality:
        src2_examples_iter, num_src2_feats = \
            _make_examples_nfeats_tpl(second_data_type, second_src_path, src_dir,
                                      src_seq_length_trunc, sample_rate,
                                      window_size, window_stride,
                                      window, normalize_audio, side='src2',
                                      file_to_tensor_fn=file_to_tensor_fn)

    # For all data types, the tgt side corpus is in form of text.
    tgt_examples_iter, num_tgt_feats = \
        TextDataset.make_text_examples_nfeats_tpl(
            tgt_path, tgt_seq_length_trunc, "tgt")

    if use_second_modality:
        dataset = MultiModalDataset(fields,
                                    src_examples_iter,
                                    src2_examples_iter,
                                    second_data_type,
                                    tgt_examples_iter,
                                    num_src_feats,
                                    num_src2_feats,
                                    num_tgt_feats,
                                    src_seq_length=src_seq_length,
                                    tgt_seq_length=tgt_seq_length,
                                    use_filter_pred=use_filter_pred)
    elif data_type == 'text':
        dataset = TextDataset(fields,
                              src_examples_iter,
                              tgt_examples_iter,
                              num_src_feats,
                              num_tgt_feats,
                              src_seq_length=src_seq_length,
                              tgt_seq_length=tgt_seq_length,
                              dynamic_dict=dynamic_dict,
                              use_filter_pred=use_filter_pred)

    elif data_type == 'img':
        dataset = ImageDataset(fields,
                               src_examples_iter,
                               tgt_examples_iter,
                               num_src_feats,
                               num_tgt_feats,
                               tgt_seq_length=tgt_seq_length,
                               use_filter_pred=use_filter_pred)

    else:  # data_type == 'audio' (guaranteed by the validation above)
        dataset = AudioDataset(fields,
                               src_examples_iter,
                               tgt_examples_iter,
                               num_src_feats,
                               num_tgt_feats,
                               tgt_seq_length=tgt_seq_length,
                               sample_rate=sample_rate,
                               window_size=window_size,
                               window_stride=window_stride,
                               window=window,
                               normalize_audio=normalize_audio,
                               use_filter_pred=use_filter_pred)

    return dataset
コード例 #2
0
def main():
    """Load data and vocabularies, build a multi-encoder model, and train it.

    Reads the module-level ``opt``. When ``opt.train_from`` is set, the
    checkpoint's stored options become ``model_opt`` and ``opt.start_epoch``
    is set to the epoch after the checkpoint's. Model and optimizer state are
    NOT restored from the checkpoint yet (see the TODOs below).
    """
    # TODO: Load checkpoint if we resume from a previous training.
    if opt.train_from:  # opt.train_from defaults to False.
        print('Loading checkpoint from %s' % opt.train_from)
        # Returning `storage` unchanged keeps all tensors on the CPU.
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
        opt.start_epoch = checkpoint['epoch'] + 1
    else:
        checkpoint = None  # placeholder until checkpoint loading is supported
        model_opt = opt

    train_dataset = lazily_load_dataset('train')
    # Only the first item yielded by lazily_load_dataset is used here, to
    # discover which attributes the examples actually carry.
    first_dataset = next(train_dataset)
    # Expected example attributes: 'indices', 'src' (unused, should be
    # dropped during preprocessing), 'src_audio', 'src_path' (unused),
    # 'src_text' and 'tgt'.
    example_keys = first_dataset[0].__dict__

    # Load vocabularies; expected keys are 'src_text' and 'tgt'.
    vocabs = dict(torch.load(opt.data + '.vocab.pt'))

    # One field set per encoder (source data type). Numbers of source and
    # target features are 0; supporting them would need more modifications.
    text_fields = TextDataset.get_fields(0, 0)
    audio_fields = AudioDataset.get_fields(0, 0)
    # No vocabulary is attached to the audio 'src' field because it is not
    # used below.

    for name, vocab in vocabs.items():
        # Tokens missing from stoi map to index 0.
        vocab.stoi = defaultdict(lambda: 0, vocab.stoi)
        if name == 'src_text':
            text_fields['src'].vocab = vocab
        elif name == 'tgt':
            # The same target vocabulary is attached to both field sets.
            text_fields['tgt'].vocab = vocab
            audio_fields['tgt'].vocab = vocab
        else:
            # Previously any other key silently overwrote the target vocab.
            print(' * ignoring unexpected vocabulary %r' % (name,))

    # Keep only the fields the examples actually provide.
    text_fields = {k: f for k, f in text_fields.items() if k in example_keys}
    audio_fields = {k: f for k, f in audio_fields.items() if k in example_keys}

    print(' * vocabulary size. text source = %d; target = %d' %
          (len(text_fields['src'].vocab), len(text_fields['tgt'].vocab)))
    print(' * vocabulary size. audio target = %d' %
          len(audio_fields['tgt'].vocab))

    fields_dict = {'text': text_fields, 'audio': audio_fields}

    # Build model.
    model = build_multiencoder_model(
        model_opt, opt, fields_dict)  # TODO: support using 'checkpoint'.
    tally_parameters(model)
    check_save_model_path()

    # Build optimizer.
    optim = build_optim(model)  # TODO: support using 'checkpoint'.

    # Do training.
    train_model(model,
                fields_dict,
                optim,
                data_type='multi',
                model_opt=model_opt)