Example no. 1
def make_word_embeddings(opt, word_dict, fields):
    word_padding_idx = word_dict.stoi[table.IO.PAD_WORD]
    num_word = len(word_dict)
    emb_word = nn.Embedding(num_word,
                            opt.word_vec_size,
                            padding_idx=word_padding_idx)

    if len(opt.pre_word_vecs) > 0:
        vectors = torchtext.vocab.GloVe(name="840B",
                                        cache=opt.pre_word_vecs,
                                        dim=str(opt.word_vec_size))
        fields["src"].vocab.load_vectors(vectors)
        emb_word.weight.data.copy_(fields["src"].vocab.vectors)

    if opt.fix_word_vecs:
        # <unk> is 0
        num_special = len(table.IO.special_token_list)
        # zero vectors in the fixed embedding (emb_word)
        emb_word.weight.data[:num_special].zero_()
        emb_special = nn.Embedding(num_special,
                                   opt.word_vec_size,
                                   padding_idx=word_padding_idx)
        emb = PartUpdateEmbedding(num_special, emb_special, emb_word)
        return emb
    else:
        return emb_word
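
All of the examples on this page wrap the word embedding in a PartUpdateEmbedding when fix_word_vecs is set, but that class itself is not shown here. Below is a minimal sketch of what such a module might look like, assuming the semantics implied by the call sites: the first num_special rows (the ones zeroed just above) stay trainable through emb_special, while the pretrained rows in emb_word stay frozen. The constructor signature mirrors the calls above; the body is illustrative, not the project's actual implementation.

import torch.nn as nn


class PartUpdateEmbedding(nn.Module):
    """Embedding whose first `update_index` rows stay trainable (emb_update)
    while all remaining rows come from a frozen table (emb_fixed)."""

    def __init__(self, update_index, emb_update, emb_fixed):
        super().__init__()
        self.update_index = update_index
        self.emb_update = emb_update
        self.emb_fixed = emb_fixed
        # freeze the large pretrained table (assumption: "fixed" means no gradient updates)
        self.emb_fixed.weight.requires_grad = False
        self.embedding_dim = emb_update.embedding_dim

    def forward(self, inp):
        # look up both tables; indices below update_index take the trainable rows,
        # everything else takes the frozen pretrained rows
        upd = self.emb_update(inp.clamp(max=self.update_index - 1))
        fix = self.emb_fixed(inp)
        mask = (inp < self.update_index).unsqueeze(-1).to(upd.dtype)
        return mask * upd + (1.0 - mask) * fix
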
Example no. 2
def make_word_embeddings(opt, word_dict, fields):
    word_padding_idx = word_dict.stoi[table.IO.PAD_WORD]
    num_word = len(word_dict)
    emb_word = nn.Embedding(num_word,
                            opt.word_vec_size,
                            padding_idx=word_padding_idx)

    if len(opt.pre_word_vecs) > 0:
        if opt.word_vec_size == 150:
            dim_list = ['100', '50']
        elif opt.word_vec_size == 250:
            dim_list = ['200', '50']
        else:
            dim_list = [
                str(opt.word_vec_size),
            ]
        vectors = [
            torchtext.vocab.GloVe(name="6B", cache=opt.pre_word_vecs, dim=it)
            for it in dim_list
        ]
        word_dict.load_vectors(vectors)
        emb_word.weight.data.copy_(word_dict.vectors)

    if opt.fix_word_vecs:
        # <unk> is 0
        num_special = len(table.IO.special_token_list)
        # zero vectors in the fixed embedding (emb_word)
        emb_word.weight.data[:num_special].zero_()
        emb_special = nn.Embedding(num_special,
                                   opt.word_vec_size,
                                   padding_idx=word_padding_idx)
        emb = PartUpdateEmbedding(num_special, emb_special, emb_word)
        return emb
    else:
        return emb_word
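
Example no. 2 builds 150- or 250-dimensional embeddings by loading two GloVe tables (100d + 50d or 200d + 50d). This works because the legacy torchtext Vocab.load_vectors accepts a list of Vectors objects and concatenates them per word, so vocab.vectors ends up with the summed dimensionality. A small self-contained sketch, assuming the legacy (pre-0.9) torchtext vocabulary API and a local cache directory (the 6B GloVe archive is downloaded on first use if it is not already cached):

from collections import Counter

import torchtext

# legacy torchtext Vocab built from a toy token counter
counter = Counter(["select", "count", "where", "from"])
vocab = torchtext.vocab.Vocab(counter, specials=["<unk>", "<pad>"])

# two GloVe tables: 100d + 50d -> 150d rows after concatenation
vectors = [
    torchtext.vocab.GloVe(name="6B", cache="./glove_cache", dim=d)
    for d in ("100", "50")
]
vocab.load_vectors(vectors)

print(vocab.vectors.shape)  # (len(vocab), 150)
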
Example no. 3
def make_word_embeddings(args, vocab: torchtext.vocab.Vocab):
    word_padding_idx = vocab.stoi[table.IO.PAD_WORD]
    num_word = len(vocab)
    emb_word = nn.Embedding(num_word,
                            args.word_emb_size,
                            padding_idx=word_padding_idx)

    logger.info(" * using embeddings [%s]" % args.word_embeddings)

    if args.word_embeddings != '':  # TODO: might get rid of this check?

        # load custom embeddings
        if args.use_custom_embeddings:
            logger.info(" * loading custom embeddings")
            vocab.load_vectors(
                vectors=[load_glove_fine_tuned(args, get_only_dict=False)])
            emb_word.weight.data.copy_(vocab.vectors)

        else:
            logger.info(" * using default embeddings")

            if args.word_emb_size == 150:
                dim_list = ['100', '50']
            elif args.word_emb_size == 250:
                dim_list = ['200', '50']
            else:
                dim_list = [
                    str(args.word_emb_size),
                ]

            vectors = [
                torchtext.vocab.GloVe(name="6B",
                                      cache=args.word_embeddings,
                                      dim=it) for it in dim_list
            ]
            vocab.load_vectors(vectors)
            emb_word.weight.data.copy_(vocab.vectors)
    # ---

    if args.fix_word_vecs:
        # <unk> is 0
        num_special = len(table.IO.SPECIAL_TOKEN_LIST)
        # zero vectors in the fixed embedding (emb_word)
        emb_word.weight.data[:num_special].zero_()
        emb_special = nn.Embedding(num_special,
                                   args.word_emb_size,
                                   padding_idx=word_padding_idx)
        emb = PartUpdateEmbedding(num_special, emb_special, emb_word)
        return emb
    else:
        return emb_word
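
Example no. 3 delegates the custom-embedding path to load_glove_fine_tuned, which is project-specific and not shown here. A hypothetical sketch of such a loader, assuming the fine-tuned vectors live in a GloVe-style text file (one word followed by its floats per line) wrapped in torchtext.vocab.Vectors so that vocab.load_vectors can consume it; the file name and the get_only_dict behaviour are placeholders:

import torchtext


def load_glove_fine_tuned(args, get_only_dict=False):
    """Hypothetical loader: wraps a fine-tuned embedding text file
    as a torchtext Vectors object usable by vocab.load_vectors."""
    vectors = torchtext.vocab.Vectors(
        name="glove_fine_tuned.txt",   # placeholder file name
        cache=args.word_embeddings,    # directory holding the file
    )
    if get_only_dict:
        # optionally expose only the word -> row-index mapping
        return vectors.stoi
    return vectors
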
Example no. 4
def make_word_embeddings(opt, word_dict, fields):
    word_padding_idx = word_dict.stoi[table.IO.PAD_WORD]
    num_word = len(word_dict)
    print(opt.word_vec_size)
    # in PyTorch, nn.Embedding is a lookup table whose weight is a num_word x opt.word_vec_size matrix
    emb_word = nn.Embedding(num_word,
                            opt.word_vec_size,
                            padding_idx=word_padding_idx)

    if len(opt.pre_word_vecs) > 0:
        # torchtext.vocab.GloVe(): if a cached file named glove.{name}.{dim}d.txt.pt exists, it is loaded directly;
        # otherwise the raw glove.{name}.{dim}d.txt file is read (and cached) instead
        vectors = torchtext.vocab.GloVe(name="840B",
                                        cache=opt.pre_word_vecs,
                                        dim=str(opt.word_vec_size))

        # emb_word has the same dimensionality as the 840B GloVe vectors.
        # Even though GloVe covers a huge vocabulary, many dataset tokens are missing from it
        # (e.g. state/territory, text/background); their rows stay [0, 0, ..., 0], just like the special tokens.
        fields["src"].vocab.load_vectors(vectors)
        emb_word.weight.data.copy_(
            fields["src"].vocab.vectors
        )  # initialize emb_word.weight from fields["src"].vocab.vectors

    if opt.fix_word_vecs:  # fix_word_vecs is expected to be true in this setup
        # <unk> is 0
        num_special = len(table.IO.special_token_list)
        # zero the special-token rows in the fixed embedding (emb_word); they are already zero
        # after loading GloVe, since these tokens are not real words
        emb_word.weight.data[:num_special].zero_()

        emb_special = nn.Embedding(num_special,
                                   opt.word_vec_size,
                                   padding_idx=word_padding_idx)
        emb = PartUpdateEmbedding(num_special, emb_special, emb_word)
        return emb
    else:
        return emb_word
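
The copy-then-zero initialization used in all four examples can be reproduced in isolation; a self-contained sketch with arbitrary illustrative sizes (in the examples they come from the vocabulary and the parsed options):

import torch
import torch.nn as nn

num_word, dim, pad_idx, num_special = 100, 8, 1, 4

emb_word = nn.Embedding(num_word, dim, padding_idx=pad_idx)

# stand-in for a pretrained GloVe matrix aligned with the vocabulary
pretrained = torch.randn(num_word, dim)
emb_word.weight.data.copy_(pretrained)

# zero the rows reserved for special tokens (<unk>, <pad>, ...)
emb_word.weight.data[:num_special].zero_()

print(emb_word.weight[:num_special].abs().sum())  # tensor(0., grad_fn=...)
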