Example #1
    def __init__(self, device, **kwargs):
        path_to_data = kwargs.pop('path_to_data')
        super().__init__(device, **kwargs)

        # Defining fields
        TEXT = data.ReversibleField(init_token='<sos>',
                                    eos_token='<eos>',
                                    lower=self.lower,
                                    tokenize=tokenizer_ptb,
                                    include_lengths=True,
                                    fix_length=self._fix_length,
                                    batch_first=True)
        train, valid, test = datasets.PennTreebank.splits(TEXT,
                                                          root=path_to_data)

        # Filter each split in place, keeping only examples within the
        # configured length bounds.
        for split in (train, valid, test):
            if self.min_len is not None:
                split.examples = [
                    x for x in split.examples if len(x.text) >= self.min_len
                ]
            if self.max_len is not None:
                split.examples = [
                    x for x in split.examples if len(x.text) <= self.max_len
                ]

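        # A fix_length of -1 acts as a sentinel: pad every batch to the
        # longest example found across the train/valid/test splits.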
        if self._fix_length == -1:
            TEXT.fix_length = max([train.max_len, valid.max_len, test.max_len])

        self._train_iter, self._valid_iter, self._test_iter = BucketIterator.splits(
            (train, valid, test),
            batch_sizes=(self.batch_size, self.batch_size, self.batch_size),
            sort_key=lambda x: len(x.text),
            sort_within_batch=True,
            repeat=False,
            device=self.device)

        TEXT.build_vocab(train,
                         vectors=self.emb_dim,
                         vectors_cache=self.path_to_vectors,
                         max_size=self.voc_size,
                         min_freq=self.min_freq)
        self.train_vocab = TEXT.vocab
        self._fix_length = TEXT.fix_length
Example #2
    def __init__(self, device, **kwargs):
        server = kwargs.pop('server', 'localhost')
        db_name = kwargs.pop('db')
        data_collection_name = kwargs.pop('data_collection')
        super().__init__(device, **kwargs)

        train_col = f'{data_collection_name}_train'
        val_col = f'{data_collection_name}_validation'
        test_col = f'{data_collection_name}_test'

        FIELD_TEXT = data.ReversibleField(init_token='<sos>',
                                          eos_token='<eos>',
                                          unk_token='<unk>',
                                          tokenize=partial(
                                              tokenizer,
                                              punct=self.punctuation),
                                          batch_first=True,
                                          use_vocab=True,
                                          fix_length=self._fix_length,
                                          include_lengths=True,
                                          lower=self.lower)

        train, valid, test = datasets.Yelp2019.splits(server,
                                                      db_name,
                                                      text_field=FIELD_TEXT,
                                                      train=train_col,
                                                      validation=val_col,
                                                      test=test_col,
                                                      **self.dataset_kwargs)
        if self._fix_length == -1:
            FIELD_TEXT.fix_length = max(
                [train.max_len, valid.max_len, test.max_len])

        self._train_iter, self._valid_iter, self._test_iter = BucketIterator.splits(
            (train, valid, test),
            batch_sizes=(self.batch_size, self.batch_size, len(test)),
            sort_key=lambda x: len(x.text),
            sort_within_batch=True,
            repeat=False,
            device=self.device)
        FIELD_TEXT.build_vocab(train,
                               vectors=self.emb_dim,
                               vectors_cache=self.path_to_vectors,
                               max_size=self.voc_size,
                               min_freq=self.min_freq)
        self.train_vocab = FIELD_TEXT.vocab
        self._fix_length = FIELD_TEXT.fix_length
Example #3
File: loader.py  Project: cesarali/Tyche
    def __init__(self, **kwargs):
        batch_size = kwargs.get('batch_size')
        path_to_data = kwargs.pop('path_to_data')
        path_to_vectors = kwargs.pop('path_to_vectors')
        emb_dim = kwargs.pop('emb_dim')
        voc_size = kwargs.pop('voc_size', None)
        min_freq = kwargs.pop('min_freq', 1)
        fix_len = kwargs.pop('fix_len', None)

        # Defining fields
        TEXT = data.ReversibleField(init_token='<sos>',
                                    eos_token='<eos>',
                                    unk_token='UNK',
                                    lower=True,
                                    tokenize=tokenizer,
                                    include_lengths=True,
                                    fix_length=fix_len,
                                    batch_first=True)
        train, valid, test = datasets.PennTreebank.splits(TEXT,
                                                          root=path_to_data)

        if fix_len == -1:
            TEXT.fix_length = max([train.max_len, valid.max_len, test.max_len])

        self._train_iter, self._valid_iter, self._test_iter = data.BucketIterator.splits(
            (train, valid, test),
            batch_sizes=(batch_size, batch_size, len(test)),
            sort_key=lambda x: len(x.text),
            sort_within_batch=True,
            repeat=False)

        TEXT.build_vocab(train,
                         vectors=emb_dim,
                         vectors_cache=path_to_vectors,
                         max_size=voc_size,
                         min_freq=min_freq)
        self.train_vocab = TEXT.vocab
        self.fix_length = TEXT.fix_length
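
A minimal consumption sketch for the loader in Example #3. The class name
PTBLoader is hypothetical (the listing only shows __init__), and iterating
the private _train_iter directly is for illustration only. Note that emb_dim
is forwarded to build_vocab(vectors=...), so in these loaders it names a
pretrained vector set such as 'glove.6B.100d' rather than an integer
dimension. With include_lengths=True, each batch's text attribute is an
(ids, lengths) pair, and the ids map back to tokens through the vocabulary
built on the training split:

    # Hypothetical class wrapping the __init__ from Example #3.
    loader = PTBLoader(batch_size=32,
                       path_to_data='./data',
                       path_to_vectors='./vectors',
                       emb_dim='glove.6B.100d')  # pretrained vector name
    for batch in loader._train_iter:
        ids, lengths = batch.text  # include_lengths=True -> (ids, lengths)
        tokens = [loader.train_vocab.itos[i] for i in ids[0].tolist()]
        print(' '.join(tokens), int(lengths[0]))
        break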
Example #4
    def __init__(self, device, **kwargs):
        path_to_data = kwargs.pop('path_to_data')
        super().__init__(device, **kwargs)

        # Defining fields
        TEXT = data.ReversibleField(init_token='<sos>',
                                    eos_token='<eos>',
                                    tokenize=None,
                                    lower=self.lower,
                                    include_lengths=True,
                                    fix_length=self._fix_length,
                                    batch_first=True)
        train, valid, test = datasets.WikiText103.splits(TEXT,
                                                         root=path_to_data)

        # Rebind the split names: assigning to the loop variable would
        # silently discard whatever _preprocess_wiki returns.
        train, valid, test = (
            _preprocess_wiki(split, self.min_len, self.max_len)
            for split in (train, valid, test))

        if self._fix_length == -1:
            TEXT.fix_length = max([train.max_len, valid.max_len, test.max_len])

        self._train_iter, self._valid_iter, self._test_iter = BucketIterator.splits(
            (train, valid, test),
            batch_sizes=(self.batch_size, self.batch_size, len(test)),
            sort_key=lambda x: len(x.text),
            sort_within_batch=True,
            repeat=False,
            device=self.device)

        TEXT.build_vocab(train,
                         vectors=self.emb_dim,
                         vectors_cache=self.path_to_vectors,
                         max_size=self.voc_size,
                         min_freq=self.min_freq)
        self.train_vocab = TEXT.vocab
        self._fix_length = TEXT.fix_length
Example #5
File: loader.py  Project: cesarali/Tyche
    def __init__(self, **kwargs):
        batch_size = kwargs.pop('batch_size', 32)
        path_to_vectors = kwargs.pop('path_to_vectors')
        emb_dim = kwargs.pop('emb_dim')
        voc_size = kwargs.pop('voc_size', None)
        min_freq = kwargs.pop('min_freq', 1)
        fix_len = kwargs.pop('fix_len', None)
        bptt_length = kwargs.pop('bptt_len', 20)
        bow_size = kwargs.get('bow_size')
        server = kwargs.pop('server', 'localhost')
        data_collection_name = kwargs.pop('data_collection')

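        # FIELD_TIME runs raw time stamps through the project's `delta`
        # preprocessing; FIELD_BOW pads with a sparse CSR row and applies
        # `expand_bow_vector` as a postprocessing step at batch time.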
        FIELD_TIME = data.BPTTField(bptt_len=bptt_length,
                                    use_vocab=False,
                                    include_lengths=True,
                                    pad_token=[0, 0, 0],
                                    preprocessing=delta)
        FIELD_BOW = data.BPTTField(bptt_len=bptt_length,
                                   use_vocab=False,
                                   include_lengths=False,
                                   pad_token=csr_matrix((1, bow_size)),
                                   preprocessing=unpack_bow2seq,
                                   postprocessing=expand_bow_vector,
                                   dtype=torch.float32)

        FIELD_TEXT = data.ReversibleField(init_token='<sos>',
                                          eos_token='<eos>',
                                          unk_token='UNK',
                                          tokenize=tokenizer,
                                          batch_first=True,
                                          use_vocab=True,
                                          is_target=True)
        NESTED_TEXT_FIELD = NestedField(FIELD_TEXT,
                                        use_vocab=False,
                                        preprocessing=unpack_text)
        train_col = f'{data_collection_name}_train'
        val_col = f'{data_collection_name}_validation'
        test_col = f'{data_collection_name}_test'
        train, valid, test = datasets.RatebeerBow2Seq.splits(
            server,
            time_field=FIELD_TIME,
            text_field=NESTED_TEXT_FIELD,
            bow_field=FIELD_BOW,
            train=train_col,
            validation=val_col,
            test=test_col,
            **kwargs)

        if fix_len == -1:
            max_len = max([train.max_len, valid.max_len, test.max_len])
            FIELD_TIME.fix_length = max_len
            FIELD_TEXT.fix_length = max_len

        self._train_iter, self._valid_iter, self._test_iter = data.BPTTIterator.splits(
            (train, valid, test),
            batch_sizes=(batch_size, batch_size, len(test)),
            sort_key=lambda x: len(x.time),
            sort_within_batch=True,
            repeat=False,
            bptt_len=bptt_length)
        self.bptt_length = bptt_length
        NESTED_TEXT_FIELD.build_vocab(train,
                                      vectors=emb_dim,
                                      vectors_cache=path_to_vectors,
                                      max_size=voc_size,
                                      min_freq=min_freq)
        self.train_vocab = NESTED_TEXT_FIELD.vocab
        self.fix_length = NESTED_TEXT_FIELD.fix_length
Example #6
    def __init__(self, device, **kwargs):
        batch_size = kwargs.pop('batch_size')
        path_to_vectors = kwargs.pop('path_to_vectors')
        emb_dim = kwargs.pop('emb_dim')
        voc_size = kwargs.pop('voc_size', None)
        min_freq = kwargs.pop('min_freq')
        fix_len = kwargs.pop('fix_len', None)
        bptt_length = kwargs.pop('bptt_len')
        bow_size = kwargs.get('bow_size')
        server = kwargs.pop('server', 'localhost')
        data_collection_name = kwargs.pop('data_collection')
        self.__t_max = kwargs.pop('t_max')

        train_col = f'{data_collection_name}_train'
        val_col = f'{data_collection_name}_validation'
        test_col = f'{data_collection_name}_test'

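        # When t_max is not supplied, a single aggregation over the training
        # collection derives the global min/max event times.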
        db = MongoClient('mongodb://' + server)['hawkes_text']
        col = db[train_col]
        if self.__t_max is None:
            min_max_values = list(
                col.aggregate([{
                    "$project": {
                        "_id": 0,
                        "time": 1
                    }
                }, {
                    "$unwind": "$time"
                }, {
                    "$group": {
                        "_id": None,
                        "max": {
                            "$max": "$time"
                        },
                        "min": {
                            "$min": "$time"
                        }
                    }
                }, {
                    "$limit": 1
                }]))[0]
            self.__t_min = min_max_values['min']
            self.__t_max = min_max_values['max']
        # part_scale = partial(min_max_scale, min_value=self.__t_min, max_value=self.__t_max)
        FIELD_TIME = data.BPTTField(bptt_length=bptt_length,
                                    use_vocab=False,
                                    fix_length=fix_len,
                                    include_lengths=True,
                                    pad_token=[-1.0, -1.0, -1.0],
                                    preprocessing=partial(delta,
                                                          t_max=self.__t_max),
                                    dtype=torch.float64)
        FIELD_BOW = data.BPTTField(bptt_length=bptt_length,
                                   use_vocab=False,
                                   fix_length=fix_len,
                                   include_lengths=False,
                                   pad_token=np.zeros(bow_size),
                                   preprocessing=unpack_bow2seq,
                                   dtype=torch.float64)

        FIELD_TEXT = data.ReversibleField(init_token='<sos>',
                                          eos_token='<eos>',
                                          tokenize=tokenizer,
                                          batch_first=True,
                                          use_vocab=True)
        NESTED_TEXT_FIELD = data.NestedBPTTField(FIELD_TEXT,
                                                 bptt_length=bptt_length,
                                                 use_vocab=False,
                                                 fix_length=fix_len,
                                                 preprocessing=unpack_text,
                                                 include_lengths=True)

        train, valid, test = datasets.TextPointDataSet.splits(
            server,
            time_field=FIELD_TIME,
            text_field=NESTED_TEXT_FIELD,
            bow_field=FIELD_BOW,
            train=train_col,
            validation=val_col,
            test=test_col,
            **kwargs)

        if fix_len == -1:
            max_len = max([train.max_len, valid.max_len, test.max_len])
            FIELD_TIME.fix_length = max_len
            FIELD_BOW.fix_length = max_len
            NESTED_TEXT_FIELD.fix_length = max_len
        self._train_iter, self._valid_iter, self._test_iter = data.BPTTIterator.splits(
            (train, valid, test),
            batch_sizes=(batch_size, batch_size, len(test)),
            sort_key=lambda x: len(x.time),
            sort_within_batch=True,
            repeat=False,
            bptt_len=bptt_length,
            device=device)
        self._bptt_length = bptt_length
        NESTED_TEXT_FIELD.build_vocab(train,
                                      vectors=emb_dim,
                                      vectors_cache=path_to_vectors,
                                      max_size=voc_size,
                                      min_freq=min_freq)
        self.train_vocab = NESTED_TEXT_FIELD.vocab
        self._fix_length = NESTED_TEXT_FIELD.fix_length
        self._bow_size = bow_size
Example #7
    def __init__(self, device, dtype=torch.float32, **kwargs):
        super().__init__(device, **kwargs)
        kwargs = self.dataset_kwargs
        time_fix_len = kwargs.pop('time_fix_len', None)
        text_fix_len = kwargs.pop('text_fix_len', None)
        bptt_length = kwargs.pop('bptt_len')
        self._t_max = kwargs.pop('t_max')

        server = kwargs.pop('server', 'localhost')
        data_collection_name = kwargs.pop('data_collection')
        db_name = kwargs.pop('db')
        train_col = f'{data_collection_name}_train'
        val_col = f'{data_collection_name}_validation'
        test_col = f'{data_collection_name}_test'

        db = MongoClient(f'mongodb://{server}/')[db_name]
        col = db[train_col]
        if self._t_max is None:
            min_max_values = list(
                col.aggregate([{
                    "$project": {
                        "_id": 0,
                        "time": 1
                    }
                }, {
                    "$unwind": "$time"
                }, {
                    "$group": {
                        "_id": None,
                        "max": {
                            "$max": "$time"
                        },
                        "min": {
                            "$min": "$time"
                        }
                    }
                }, {
                    "$limit": 1
                }]))[0]
            self._t_min = min_max_values['min']
            self._t_max = min_max_values['max']
        # part_scale = partial(min_max_scale, min_value=self.__t_min, max_value=self.__t_max)
        FIELD_TIME = BPTTField(bptt_length=bptt_length,
                               use_vocab=False,
                               include_lengths=True,
                               pad_token=np.array([0., 0., -1.]),
                               preprocessing=partial(delta, t_max=self._t_max),
                               dtype=dtype,
                               fix_length=time_fix_len)

        FIELD_TEXT = data.ReversibleField(init_token='<sos>',
                                          eos_token='<eos>',
                                          unk_token='<unk>',
                                          tokenize=partial(
                                              tokenizer,
                                              punct=self.punctuation),
                                          batch_first=True,
                                          use_vocab=True,
                                          fix_length=text_fix_len)
        NESTED_TEXT_FIELD = data.NestedBPTTField(FIELD_TEXT,
                                                 bptt_length=bptt_length,
                                                 use_vocab=False,
                                                 fix_length=time_fix_len,
                                                 preprocessing=unpack_text,
                                                 include_lengths=True)

        train, valid, test = datasets.TextPointDataSet.splits(
            server,
            db_name,
            time_field=FIELD_TIME,
            text_field=NESTED_TEXT_FIELD,
            train=train_col,
            validation=val_col,
            test=test_col,
            **kwargs)

        if time_fix_len == -1:
            max_len = max([train.max_len, valid.max_len, test.max_len])
            FIELD_TIME.fix_length = max_len
            NESTED_TEXT_FIELD.fix_length = max_len
        self._train_iter, self._valid_iter, self._test_iter = data.BPTTIterator.splits(
            (train, valid, test),
            batch_sizes=(self.batch_size, self.batch_size, len(test)),
            sort_key=lambda x: x.time.shape[0],
            sort_within_batch=True,
            repeat=False,
            bptt_len=bptt_length,
            device=device)
        NESTED_TEXT_FIELD.build_vocab(train,
                                      vectors=self.emb_dim,
                                      vectors_cache=self.path_to_vectors,
                                      max_size=self.voc_size,
                                      min_freq=self.min_freq)
        self._bptt_length = bptt_length
        self.train_vocab = NESTED_TEXT_FIELD.vocab
        self._time_fix_length = NESTED_TEXT_FIELD.fix_length
        self._text_fix_length = FIELD_TEXT.fix_length