Example #1
0
def get_fields(data_type, n_src_features, n_tgt_features):
    """
    Build the torchtext fields for the given source modality.

    Args:
        data_type: type of the source input.
            Options are [text|img|audio|gcn].
        n_src_features: the number of source features to
            create `torchtext.data.Field` for.
        n_tgt_features: the number of target features to
            create `torchtext.data.Field` for.

    Returns:
        A dictionary whose keys are strings and whose values are the
        corresponding Field objects.

    Raises:
        ValueError: if `data_type` is not one of the supported options.
    """
    if data_type == 'text':
        return TextDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'img':
        return ImageDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'audio':
        return AudioDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'gcn':
        return GCNDataset.get_fields(n_src_features, n_tgt_features)
    # Previously an unknown data_type fell through and returned None,
    # deferring the failure to the caller; fail fast instead.
    raise ValueError("Unsupported data_type: %s" % data_type)
Example #2
0
def main():
    """Entry point: load data and vocab, build the multi-encoder model, train.

    Reads all configuration from the module-level `opt` namespace and has
    only side effects (model construction, logging, training); returns None.
    """
    # TODO: Load checkpoint if we resume from a previous training.
    if opt.train_from:  # opt.train_from defaults 'False'.
        print('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
        opt.start_epoch = checkpoint['epoch'] + 1
    else:
        checkpoint = None
        model_opt = opt

    train_dataset = lazily_load_dataset('train')
    ex_generator = next(train_dataset)
    # Each loaded example is expected to look like (TODO confirm against the
    # data-preparation step):
    # {'indices': 0,
    #  'src': None,            # unused; should be removed when preparing data.
    #  'src_audio': FloatTensor,
    #  'src_path': wav path,   # unused; should be removed when preparing data.
    #  'src_text': tuple,
    #  'tgt': tuple}

    # Load vocab: maps field name ('src_text', 'tgt') -> Vocab.
    vocabs = dict(torch.load(opt.data + '.vocab.pt'))

    # Get fields; we use a dict to store fields for the different
    # encoders (source data). Numbers of src/tgt features are set to 0.
    # Actually, we could use these features, but it needs more modifications.
    text_fields = TextDataset.get_fields(0, 0)
    audio_fields = AudioDataset.get_fields(0, 0)

    for k, v in vocabs.items():
        # Map out-of-vocabulary tokens to index 0 instead of raising KeyError.
        v.stoi = defaultdict(lambda: 0, v.stoi)
        if k == 'src_text':
            text_fields['src'].vocab = v
        else:
            # The target vocab is shared by the text and audio field sets.
            text_fields['tgt'].vocab = v
            audio_fields['tgt'].vocab = v

    # Keep only the fields that are actually present on the loaded examples.
    example_attrs = ex_generator[0].__dict__
    text_fields = {k: f for k, f in text_fields.items()
                   if k in example_attrs}  # 'indices', 'src', 'src_text', 'tgt'
    audio_fields = {k: f for k, f in audio_fields.items()
                    if k in example_attrs}

    print(' * vocabulary size. text source = %d; target = %d' %
          (len(text_fields['src'].vocab), len(text_fields['tgt'].vocab)))
    print(' * vocabulary size. audio target = %d' %
          len(audio_fields['tgt'].vocab))

    fields_dict = {'text': text_fields, 'audio': audio_fields}

    # Build model.
    model = build_multiencoder_model(
        model_opt, opt, fields_dict)  # TODO: support using 'checkpoint'.
    tally_parameters(model)
    check_save_model_path()

    # Build optimizer.
    optim = build_optim(model)  # TODO: support using 'checkpoint'.

    # Do training.
    train_model(model,
                fields_dict,
                optim,
                data_type='multi',
                model_opt=model_opt)