Example #1
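    # Build a vocabulary seeded with BOS/EOS/UNK, encode each sentence as a
    # sparse bag-of-words row, reduce it with PCA, and cache the result as a
    # dense tensor on the target device.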
    def __init__(self,
                 *args,
                 pca_dim=10,
                 batch_size=32,
                 device=None,
                 bos="<bos>",
                 eos="<eos>",
                 unk="<unk>",
                 **kwargs):
        super(PyTorchPCASearcher, self).__init__(*args, **kwargs)
        self.batch_size = batch_size
        self.device = device
        self.bos = bos
        self.eos = eos
        self.unk = unk
        self.pca_dim = pca_dim

        if self.device is None:
            self.device = torch.device("cpu")

        self.vocab = utils.Vocabulary()
        self.vocab.add(bos)
        self.vocab.add(eos)
        self.vocab.add(unk)
        utils.populate_vocab(self.words, self.vocab)
        self.unk_idx = self.vocab.f2i[self.unk]

        self.sparse = importlib.import_module("scipy.sparse")
        self.decomp = importlib.import_module("sklearn.decomposition")
        self.sents_csr = self.sparse.vstack(
            [self.to_csr(s) for s in self.sents])
        self.pca = self.get_pca(self.sents_csr)
        self.sents_pca = self.pca.transform(self.sents_csr)
        self.sents_tensor = torch.Tensor(self.sents_pca).to(self.device)
Example #2
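    # Read whitespace-tokenized lines, optionally wrap them with BOS/EOS
    # markers, and build a cutoff-limited vocabulary (plus <unk>) if none
    # was supplied.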
    def _load_data(self):
        reader = TextFileReader()
        with reader(self.path) as f:
            self.data = [line.rstrip().split() for line in f]
            if self.pad_eos is not None:
                self.data = [sent + [self.pad_eos] for sent in self.data]
            if self.pad_bos is not None:
                self.data = [[self.pad_bos] + sent for sent in self.data]

        if self.vocab is None:
            self.vocab = utils.Vocabulary()
            utils.populate_vocab(words=[w for s in self.data for w in s],
                                 vocab=self.vocab,
                                 cutoff=self.vocab_limit)
            self.vocab.add("<unk>")
        self.unk_idx = self.vocab.f2i.get(self.unk)
Example #3
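# Restore a model from a checkpoint; with --expand-vocab, rebuild the word
# embedding layer over the pretrained-embedding vocabulary plus the reserved
# BOS/EOS/UNK tokens.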
def prepare_model(args, vocabs):
    mdl = model.create_model(args, vocabs)
    mdl.reset_parameters()
    ckpt = torch.load(args.ckpt_path)
    mdl.load_state_dict(ckpt)
    if args.expand_vocab:
        mdl_vocab = vocabs[0]
        mdl_emb = mdl.embeds[0].weight
        emb = embeds.get_embeddings(args)
        emb.preload()
        emb = {w: v for w, v in emb}
        for rword in [args.bos, args.eos, args.unk]:
            emb[rword] = mdl_emb[mdl_vocab.f2i.get(rword)].detach().numpy()
        vocab = utils.Vocabulary()
        utils.populate_vocab(emb.keys(), vocab)
        mdl.embeds[0] = embedding.BasicEmbedding(vocab_size=len(vocab),
                                                 dim=mdl.word_dim,
                                                 allow_padding=True)
        embeds._load_embeddings(mdl.embeds[0], vocab, emb.items())
    else:
        vocab = vocabs[0]
    return mdl, vocab
Example #4
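    # Multi-file variant of the loader above: read parallel corpora, zip them
    # sentence-wise, and build one cutoff-limited vocabulary per input column
    # when missing.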
    def _load_data(self):
        reader = TextFileReader()
        self.data = []
        for path in self.paths:
            with reader(path) as f:
                data = [line.rstrip().split() for line in f]
                if self.pad_eos is not None:
                    data = [sent + [self.pad_eos] for sent in data]
                if self.pad_bos is not None:
                    data = [[self.pad_bos] + sent for sent in data]
                self.data.append(data)
        self.data = list(zip(*self.data))

        for i in range(len(self.vocabs)):
            vocab = self.vocabs[i]
            if vocab is None:
                vocab = utils.Vocabulary()
                utils.populate_vocab(
                    words=[w for s in self.data for w in s[i]],
                    vocab=vocab,
                    cutoff=self.vocab_limit)
                vocab.add("<unk>")
                self.vocabs[i] = vocab
            self.unk_idxs[i] = vocab.f2i.get(self.unk)
Example #5
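    # Same vocabulary setup as the PCA searcher above, but sentences are kept
    # as dense bag-of-words tensors without dimensionality reduction.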
    def __init__(self,
                 *args,
                 batch_size=32,
                 device=None,
                 bos="<bos>",
                 eos="<eos>",
                 unk="<unk>",
                 **kwargs):
        super(PyTorchNNSearcher, self).__init__(*args, **kwargs)
        self.batch_size = batch_size
        self.device = device
        self.bos = bos
        self.eos = eos
        self.unk = unk

        if self.device is None:
            self.device = torch.device("cpu")
        self.vocab = utils.Vocabulary()
        self.vocab.add(bos)
        self.vocab.add(eos)
        self.vocab.add(unk)
        utils.populate_vocab(self.words, self.vocab)
        self.tensors = torch.stack([self.tensorize_bow(s) for s in self.sents])
        self.tensors = self.tensors.to(self.device)
Example #6
                                     label_vocabs=label_vocab,
                                     batch_size=batch_size,
                                     shuffle=shuffle,
                                     tensor_lens=True,
                                     num_workers=1,
                                     pin_memory=True)
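# Build input and label vocabularies and construct an LSTM-CRF model.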


if __name__ == "__main__":
    input_vocabs = []
    vocab = utils.Vocabulary()
    words = utils.FileReader(input_path).words()
    vocab.add("<pad>")
    vocab.add("<unk>")
    utils.populate_vocab(words, vocab)
    input_vocabs.append(vocab)

    label_vocab = utils.Vocabulary()
    words = utils.FileReader(label_path).words()
    label_vocab.add("START")
    label_vocab.add("END")
    utils.populate_vocab(words, label_vocab)

    crf = M.CRF(len(label_vocab))
    model = M.LSTMCRF(crf=crf,
                      vocab_sizes=[len(v) for v in input_vocabs],
                      word_dims=word_dim,
                      hidden_dim=lstm_dim,
                      layers=lstm_layers,
Example #7
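# End-to-end training entry point: build and save vocabularies, initialize an
# LSTM-CRF, create dataloaders, and run the trainer.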
def main(args):
    logging.basicConfig(level=logging.INFO)
    check_arguments(args)

    logging.info("Creating vocabulary...")
    input_vocabs = []

    for path in args.input_path:
        vocab = utils.Vocabulary()
        words = utils.FileReader(path).words()
        vocab.add("<pad>")
        vocab.add("<unk>")
        utils.populate_vocab(words, vocab)
        input_vocabs.append(vocab)

    label_vocab = utils.Vocabulary()
    words = utils.FileReader(args.label_path).words()
    label_vocab.add("START")
    label_vocab.add("END")
    utils.populate_vocab(words, label_vocab)

    for i, input_vocab in enumerate(input_vocabs):
        vocab_path = os.path.join(args.save_dir,
                                  "vocab-input{}.pkl".format(i + 1))
        with open(vocab_path, "wb") as f:
            pickle.dump(input_vocab, f)
    vocab_path = os.path.join(args.save_dir, "vocab-label.pkl")
    with open(vocab_path, "wb") as f:
        pickle.dump(label_vocab, f)

    logging.info("Initializing model...")
    crf = M.CRF(len(label_vocab))
    model = M.LSTMCRF(
        crf=crf,
        vocab_sizes=[len(v) for v in input_vocabs],
        word_dims=args.word_dim,
        hidden_dim=args.lstm_dim,
        layers=args.lstm_layers,
        dropout_prob=args.dropout_prob,
        bidirectional=args.bidirectional
    )
    model.reset_parameters()
    if args.gpu:
        gpu_main = args.gpu[0]
        model = model.cuda(gpu_main)
    params = sum(np.prod(p.size()) for p in model.parameters())
    logging.info("Number of parameters: {}".format(params))

    logging.info("Loading word embeddings...")
    # for vocab, we_type, we_path, we_freeze, emb in \
    #         zip(input_vocabs, args.wordembed_type, args.wordembed_path,
    #             args.wordembed_freeze, model.embeddings):
    #     if we_type == "glove":
    #         assert we_path is not None
    #         load_glove_embeddings(emb, vocab, we_path)
    #     elif we_type == "fasttext":
    #         assert we_path is not None
    #         assert args.fasttext_path is not None
    #         load_fasttext_embeddings(emb, vocab,
    #                                  fasttext_path=args.fasttext_path,
    #                                  embedding_path=we_path)
    #     elif we_type == "none":
    #         pass
    #     else:
    #         raise ValueError("Unrecognized word embedding "
    #                          "type: {}".format(we_type))
    #
    #     if we_freeze:
    #         emb.weight.requires_grad = False

    # Copying configuration file to save directory if config file is specified.
    if args.config:
        config_path = os.path.join(args.save_dir, os.path.basename(args.config))
        shutil.copy(args.config, config_path)

    def create_dataloader(dataset):
        return D.MultiSentWordDataLoader(
            dataset=dataset,
            input_vocabs=input_vocabs,
            label_vocabs=label_vocab,
            batch_size=args.batch_size,
            shuffle=args.shuffle,
            tensor_lens=True,
            num_workers=len(args.gpu) if args.gpu is not None else 1,
            pin_memory=True
        )

    dataset = D.MultiSentWordDataset(*args.input_path, args.label_path)
    test_dataset = D.MultiSentWordDataset(*args.test_input_path,
                                          args.test_label_path)

    if args.val:
        vr = args.val_ratio
        val_dataset, _ = dataset.split(vr, 1 - vr, shuffle=args.shuffle)
    else:
        val_dataset = None

    train_dataset = dataset
    train_dataloader = create_dataloader(train_dataset)
    test_dataloader = create_dataloader(test_dataset)

    if val_dataset is not None:
        val_dataloader = create_dataloader(val_dataset)
    else:
        val_dataloader = None

    logging.info("Beginning training...")
    trainer = LSTMCRFTrainer(
        sargs=args,
        input_vocabs=input_vocabs,
        label_vocab=label_vocab,
        val_data=val_dataloader,
        model=model,
        epochs=args.epochs,
        gpus=args.gpu
    )

    trainer.train(train_dataloader, data_size=len(train_dataset))
    # trainer.validate()
    logging.info("Beginning testing...")
    # trainer.test(train_dataloader, data_size=len(train_dataset))
    # trainer.test(test_dataloader, data_size=len(test_dataset))
    logging.info("Done!")