Example #1
def load_fields(vocab):
    vocab = dict(vocab)
    fields = TableDataset.get_fields()
    for k, v in vocab.items():
        v.stoi = defaultdict(lambda: 0, v.stoi)
        fields[k].vocab = v
    return fields
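The "Hack. Can't pickle defaultdict :(" comment seen in the later examples explains the `defaultdict(lambda: 0, v.stoi)` line above: a defaultdict whose default_factory is a lambda cannot be pickled, so the vocab is serialized with a plain dict and the defaultdict (which maps unseen tokens to index 0) is rebuilt after loading. A minimal standalone sketch of that failure mode and workaround (the token/index values are made up for illustration):

import pickle
from collections import defaultdict

stoi = defaultdict(lambda: 0, {"hello": 4, "world": 5})

try:
    pickle.dumps(stoi)  # fails: the lambda default_factory is not picklable
except (pickle.PicklingError, AttributeError, TypeError) as exc:
    print("cannot pickle:", exc)

# Workaround used in these examples: persist a plain dict,
# then rebuild the defaultdict after loading.
payload = pickle.dumps(dict(stoi))
restored = defaultdict(lambda: 0, pickle.loads(payload))
print(restored["unseen"])  # -> 0, the index reserved for unknown tokens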
Example #2
def load_fields(vocab):
    vocab = dict(vocab)
    fields = ONMTDataset.get_fields(len(ONMTDataset.collect_features(vocab)))
    for k, v in vocab.items():
        # Hack. Can't pickle defaultdict :(
        v.stoi = defaultdict(lambda: 0, v.stoi)
        fields[k].vocab = v
    return fields
Example #3
def load_fields(vocab: list):
    vocab = dict(vocab)
    fields = TableDataset.get_fields()

    for k, v in vocab.items():
        # Hack. Can't pickle defaultdict :(
        v.stoi = defaultdict(lambda: 0, v.stoi)
        fields[k].vocab = v

    return fields
Example #4
def load_fields(vocab):
    vocab = dict(vocab)
    n_src_features = len(collect_features(vocab, 'src'))
    n_tgt_features = len(collect_features(vocab, 'tgt'))
    fields = get_fields(n_src_features, n_tgt_features)
    for k, v in vocab.items():
        # Hack. Can't pickle defaultdict :(
        v.stoi = defaultdict(lambda: 0, v.stoi)
        fields[k].vocab = v
    return fields
Example #5
def load_fields_from_vocab(vocab, data_type="text"):
    """
    Load Field objects from `vocab.pt` file.
    """
    vocab = dict(vocab)
    fields = get_fields(data_type)
    for k, v in vocab.items():
        # Hack. Can't pickle defaultdict :(
        v.stoi = defaultdict(lambda: 0, v.stoi)
        fields[k].vocab = v
    return fields
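A hypothetical call site for the version above, assuming (as in OpenNMT-py's preprocessing) that `vocab.pt` was written with torch.save and holds a list of (field name, Vocab) pairs; the file path and printed keys are illustrative only:

import torch

vocab_list = torch.load("data/demo.vocab.pt")   # path is an assumption
fields = load_fields_from_vocab(vocab_list, data_type="text")

print(list(fields.keys()))       # e.g. ['src', 'tgt', 'indices', ...]
print(len(fields["src"].vocab))  # source-side vocabulary size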
Example #6
def load_glove_fine_tuned(args, get_only_dict=False):
    """
    :return: return dict if get_only_dict is True, otherwise return Vectors
    """

    _cache = os.path.dirname(args.word_embeddings)
    _name = os.path.basename(args.word_embeddings)

    assert os.path.join(_cache, _name) == args.word_embeddings

    logger.info(" * load vocab from [%s]" % args.vocab_file)
    logger.info(" * load pre-trained glove from [%s]" % args.pt_embeddings)
    logger.info(" * load fine-tuned glove from [%s]" % args.word_embeddings)

    glove_emb = load_orig_glove(args.pt_embeddings)

    ft_glove_emb_arr = pickle.load(open(args.word_embeddings, "rb"))
    vocab = pickle.load(open(args.vocab_file, "rb"))

    ft_glove_emb = {w: ft_glove_emb_arr[i] for w, i in vocab.items()}

    for w in tqdm(ft_glove_emb,
                  desc="Mixing embeddings: (%.2f ft, %.2f pt)" %
                  (args.ft_factor, args.pt_factor)):
        if w not in glove_emb:
            glove_emb[w] = ft_glove_emb[w]
        else:
            glove_emb[w] = (args.ft_factor * ft_glove_emb[w]
                            + args.pt_factor * glove_emb[w])

    if get_only_dict:
        logger.info(" * returning emb dict")
        return glove_emb

    else:
        logger.info(" * returning torchtext.vocab.Vectors")

        # save emb_dict as .pt
        itos = list(glove_emb.keys())
        stoi = {}
        vectors = {}

        for i, w in tqdm(enumerate(glove_emb),
                         total=len(glove_emb),
                         desc="Construct stoi and vectors"):
            stoi[w] = i
            vectors[i] = torch.FloatTensor(glove_emb[w])

        dim = len(vectors[0])
        torch.save([itos, stoi, vectors, dim],
                   os.path.join(_cache, _name + ".pt"))

        # len(vocab) x dim
        return torchtext.vocab.Vectors(name=_name, cache=_cache)
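A hypothetical invocation of load_glove_fine_tuned: the attribute names on args are taken from the function body, while the paths and mixing factors below are assumptions for illustration.

from argparse import Namespace

args = Namespace(
    word_embeddings="cache/glove_fine_tuned.pkl",  # pickled fine-tuned vector array
    vocab_file="cache/vocab.pkl",                  # pickled {word: index} mapping
    pt_embeddings="cache/glove.840B.300d.txt",     # original pre-trained GloVe file
    ft_factor=0.5,                                 # weight of the fine-tuned vectors
    pt_factor=0.5,                                 # weight of the pre-trained vectors
)

vectors = load_glove_fine_tuned(args)                        # torchtext.vocab.Vectors
emb_dict = load_glove_fine_tuned(args, get_only_dict=True)   # plain {word: vector} dict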
Example #7
def load_fields_from_vocab(vocab, data_type="text"):
    """
    Load Field objects from `vocab.pt` file.
    """
    vocab = dict(vocab)
    n_src_features = len(collect_features(vocab, 'src'))
    n_tgt_features = len(collect_features(vocab, 'tgt'))
    fields = get_fields(data_type, n_src_features, n_tgt_features)
    for k, v in vocab.items():
        # Hack. Can't pickle defaultdict :(
        v.stoi = defaultdict(lambda: 0, v.stoi)
        fields[k].vocab = v
    return fields
Example #8
def load_fields_from_vocab(vocab, data_type="text"):
    """
    vocab: a list of (field name, torchtext.vocab.Vocab) pairs
    data_type: text, img, or audio
    returns: a dictionary whose keys are the field names and whose values
             are field objects with the vocab set to the corresponding vocab
             object from the input.
    """
    vocab = dict(vocab)
    n_src_features = len(collect_features(vocab, 'src'))
    n_tgt_features = len(collect_features(vocab, 'tgt'))
    fields = get_fields(data_type, n_src_features, n_tgt_features)
    for k, v in vocab.items():
        fields[k].vocab = v
    return fields
Example #9
def load_fields_from_vocab(vocab, data_type="text"):
    """
    Load Field objects from `vocab.pt` file.
    """
    vocab = dict(vocab)
    n_src_features = len(collect_features(vocab, 'src'))
    n_qa_features = len(collect_features(vocab, 'qa'))
    n_tgt_features = len(collect_features(vocab, 'tgt'))
    fields = get_fields(n_src_features, n_qa_features, n_tgt_features, data_type)
    for k, v in vocab.items():
        # Hack. Can't pickle defaultdict :(
        v.stoi = defaultdict(lambda: 0, v.stoi)
        fields[k].vocab = v
        if isinstance(fields[k], NestedField):
            fields[k].nesting_field.vocab = v
    return fields
Example #10
def load_fields_from_vocab(vocab, data_type="text"):
    """
    Load Field objects from `vocab.pt` file.
    The input vocab is a list of (field name, Vocab) tuples.
    """
    vocab = dict(vocab)
    n_src_features = len(collect_features(vocab, 'src'))
    n_tgt_features = len(collect_features(vocab, 'tgt'))
    fields = get_fields(data_type, n_src_features, n_tgt_features)
    for k, v in vocab.items():
        # Hack. Can't pickle defaultdict :(
        v.stoi = defaultdict(lambda: 0, v.stoi)
        fields[k].vocab = v
    return fields
Example #11
def load_fields_from_vocab(vocab, data_type="text"):
    """
    Load Field objects from `vocab.pt` file.
    """
    vocab = dict(vocab)
    n_src_features = len(collect_features(vocab, 'src'))
    n_tgt_features = len(collect_features(vocab, 'tgt'))
    fields = get_fields(data_type, n_src_features, n_tgt_features)
    fields['ref_src'] = fields['src']
    fields['ref_tgt'] = fields['tgt']
    fields['ref_src_feat_0'] = fields['src_feat_0']
    fields['ref_src_feat_1'] = fields['src_feat_1']
    if 'src_map' in fields:
        fields['ref_src_map'] = fields['src_map']
    if 'alignment' in fields:
        fields['ref_alignment'] = fields['alignment']
    if 'indices' in fields:
        fields['ref_indices'] = fields['indices']
    for k, v in vocab.items():
        # Hack. Can't pickle defaultdict :(
        v.stoi = defaultdict(lambda: 0, v.stoi)
        fields[k].vocab = v
    return fields