def __init__(self, device, **kwargs):
    path_to_data = kwargs.pop('path_to_data')
    super().__init__(device, **kwargs)

    # Defining fields
    TEXT = data.ReversibleField(init_token='<sos>',
                                eos_token='<eos>',
                                lower=self.lower,
                                tokenize=tokenizer_ptb,
                                include_lengths=True,
                                fix_length=self._fix_length,
                                batch_first=True)

    train, valid, test = datasets.PennTreebank.splits(TEXT, root=path_to_data)

    # Drop examples outside the configured length bounds.
    if self.min_len is not None:
        train.examples = [x for x in train.examples if len(x.text) >= self.min_len]
        valid.examples = [x for x in valid.examples if len(x.text) >= self.min_len]
        test.examples = [x for x in test.examples if len(x.text) >= self.min_len]
    if self.max_len is not None:
        train.examples = [x for x in train.examples if len(x.text) <= self.max_len]
        valid.examples = [x for x in valid.examples if len(x.text) <= self.max_len]
        test.examples = [x for x in test.examples if len(x.text) <= self.max_len]

    # fix_len == -1 is treated as "pad to the longest sequence across all splits".
    if self._fix_length == -1:
        TEXT.fix_length = max([train.max_len, valid.max_len, test.max_len])

    self._train_iter, self._valid_iter, self._test_iter = BucketIterator.splits(
        (train, valid, test),
        batch_sizes=(self.batch_size, self.batch_size, self.batch_size),
        sort_key=lambda x: len(x.text),
        sort_within_batch=True,
        repeat=False,
        device=self.device)

    TEXT.build_vocab(train,
                     vectors=self.emb_dim,
                     vectors_cache=self.path_to_vectors,
                     max_size=self.voc_size,
                     min_freq=self.min_freq)
    self.train_vocab = TEXT.vocab
    self._fix_length = TEXT.fix_length

def __init__(self, device, **kwargs):
    server = kwargs.pop('server', 'localhost')
    db_name = kwargs.pop('db')
    data_collection_name = kwargs.pop('data_collection')
    super().__init__(device, **kwargs)

    train_col = f'{data_collection_name}_train'
    val_col = f'{data_collection_name}_validation'
    test_col = f'{data_collection_name}_test'

    FIELD_TEXT = data.ReversibleField(init_token='<sos>',
                                      eos_token='<eos>',
                                      unk_token='<unk>',
                                      tokenize=partial(tokenizer, punct=self.punctuation),
                                      batch_first=True,
                                      use_vocab=True,
                                      fix_length=self._fix_length,
                                      include_lengths=True,
                                      lower=self.lower)

    train, valid, test = datasets.Yelp2019.splits(server,
                                                  db_name,
                                                  text_field=FIELD_TEXT,
                                                  train=train_col,
                                                  validation=val_col,
                                                  test=test_col,
                                                  **self.dataset_kwargs)

    if self._fix_length == -1:
        FIELD_TEXT.fix_length = max([train.max_len, valid.max_len, test.max_len])

    self._train_iter, self._valid_iter, self._test_iter = BucketIterator.splits(
        (train, valid, test),
        batch_sizes=(self.batch_size, self.batch_size, len(test)),
        sort_key=lambda x: len(x.text),
        sort_within_batch=True,
        repeat=False,
        device=self.device)

    FIELD_TEXT.build_vocab(train,
                           vectors=self.emb_dim,
                           vectors_cache=self.path_to_vectors,
                           max_size=self.voc_size,
                           min_freq=self.min_freq)
    self.train_vocab = FIELD_TEXT.vocab
    self._fix_length = FIELD_TEXT.fix_length

def __init__(self, **kwargs):
    batch_size = kwargs.get('batch_size')
    path_to_data = kwargs.pop('path_to_data')
    path_to_vectors = kwargs.pop('path_to_vectors')
    emb_dim = kwargs.pop('emb_dim')
    voc_size = kwargs.pop('voc_size', None)
    min_freq = kwargs.pop('min_freq', 1)
    fix_len = kwargs.pop('fix_len', None)

    # Defining fields
    TEXT = data.ReversibleField(init_token='<sos>',
                                eos_token='<eos>',
                                unk_token='UNK',
                                lower=True,
                                tokenize=tokenizer,
                                include_lengths=True,
                                fix_length=fix_len,
                                batch_first=True)

    train, valid, test = datasets.PennTreebank.splits(TEXT, root=path_to_data)

    # fix_len == -1 is treated as "pad to the longest sequence across all splits".
    if fix_len == -1:
        TEXT.fix_length = max([train.max_len, valid.max_len, test.max_len])

    self._train_iter, self._valid_iter, self._test_iter = data.BucketIterator.splits(
        (train, valid, test),
        batch_sizes=(batch_size, batch_size, len(test)),
        sort_key=lambda x: len(x.text),
        sort_within_batch=True,
        repeat=False)

    TEXT.build_vocab(train,
                     vectors=emb_dim,
                     vectors_cache=path_to_vectors,
                     max_size=voc_size,
                     min_freq=min_freq)
    self.train_vocab = TEXT.vocab
    self.fix_length = TEXT.fix_length

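# --- Usage sketch (illustrative only) --------------------------------------------
# Assumes the __init__ above belongs to a loader class, called `PTBLoader` here
# purely for illustration, and that `emb_dim` names a torchtext vector set
# (e.g. 'glove.6B.300d').  Paths and hyper-parameters below are placeholders.
def _ptb_usage_sketch():
    loader = PTBLoader(batch_size=32,
                       path_to_data='./data',
                       path_to_vectors='./vectors',
                       emb_dim='glove.6B.300d',
                       voc_size=20000,
                       min_freq=1,
                       fix_len=-1)       # -1: pad to the longest sentence seen
    vocab = loader.train_vocab
    batch = next(iter(loader._train_iter))
    token_ids, lengths = batch.text      # (ids, lengths) because include_lengths=True
    # Map the first sequence back to tokens as a quick sanity check.
    return [vocab.itos[int(i)] for i in token_ids[0][:int(lengths[0])]]
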
def __init__(self, device, **kwargs):
    path_to_data = kwargs.pop('path_to_data')
    super().__init__(device, **kwargs)

    # Defining fields
    TEXT = data.ReversibleField(init_token='<sos>',
                                eos_token='<eos>',
                                tokenize=None,
                                lower=self.lower,
                                include_lengths=True,
                                fix_length=self._fix_length,
                                batch_first=True)

    train, valid, test = datasets.WikiText103.splits(TEXT, root=path_to_data)

    # _preprocess_wiki is assumed to return the filtered dataset; rebinding the
    # loop variable (as the original loop did) would discard that return value,
    # so the splits are reassigned explicitly here.
    train, valid, test = [
        _preprocess_wiki(dataset, self.min_len, self.max_len)
        for dataset in (train, valid, test)
    ]

    if self._fix_length == -1:
        TEXT.fix_length = max([train.max_len, valid.max_len, test.max_len])

    self._train_iter, self._valid_iter, self._test_iter = BucketIterator.splits(
        (train, valid, test),
        batch_sizes=(self.batch_size, self.batch_size, len(test)),
        sort_key=lambda x: len(x.text),
        sort_within_batch=True,
        repeat=False,
        device=self.device)

    TEXT.build_vocab(train,
                     vectors=self.emb_dim,
                     vectors_cache=self.path_to_vectors,
                     max_size=self.voc_size,
                     min_freq=self.min_freq)
    self.train_vocab = TEXT.vocab
    self._fix_length = TEXT.fix_length

def __init__(self, **kwargs):
    batch_size = kwargs.pop('batch_size', 32)
    path_to_vectors = kwargs.pop('path_to_vectors')
    emb_dim = kwargs.pop('emb_dim')
    voc_size = kwargs.pop('voc_size', None)
    min_freq = kwargs.pop('min_freq', 1)
    fix_len = kwargs.pop('fix_len', None)
    bptt_length = kwargs.pop('bptt_len', 20)
    bow_size = kwargs.get('bow_size')
    server = kwargs.pop('server', 'localhost')
    data_collection_name = kwargs.pop('data_collection')

    # Event-time sequences (BPTT-chunked, not vocabulary-indexed).
    FIELD_TIME = data.BPTTField(bptt_len=bptt_length,
                                use_vocab=False,
                                include_lengths=True,
                                pad_token=[0, 0, 0],
                                preprocessing=delta)

    # Bag-of-words sequences; sparse rows are padded with empty csr rows and
    # expanded to dense vectors in postprocessing.
    FIELD_BOW = data.BPTTField(bptt_len=bptt_length,
                               use_vocab=False,
                               include_lengths=False,
                               pad_token=csr_matrix((1, bow_size)),
                               preprocessing=unpack_bow2seq,
                               postprocessing=expand_bow_vector,
                               dtype=torch.float32)

    FIELD_TEXT = data.ReversibleField(init_token='<sos>',
                                      eos_token='<eos>',
                                      unk_token='UNK',
                                      tokenize=tokenizer,
                                      batch_first=True,
                                      use_vocab=True,
                                      is_target=True)
    NESTED_TEXT_FIELD = NestedField(FIELD_TEXT,
                                    use_vocab=False,
                                    preprocessing=unpack_text)

    train_col = f'{data_collection_name}_train'
    val_col = f'{data_collection_name}_validation'
    test_col = f'{data_collection_name}_test'

    train, valid, test = datasets.RatebeerBow2Seq.splits(server,
                                                         time_field=FIELD_TIME,
                                                         text_field=NESTED_TEXT_FIELD,
                                                         bow_field=FIELD_BOW,
                                                         train=train_col,
                                                         validation=val_col,
                                                         test=test_col,
                                                         **kwargs)

    if fix_len == -1:
        max_len = max([train.max_len, valid.max_len, test.max_len])
        FIELD_TIME.fix_length = max_len
        FIELD_TEXT.fix_length = max_len

    self._train_iter, self._valid_iter, self._test_iter = data.BPTTIterator.splits(
        (train, valid, test),
        batch_sizes=(batch_size, batch_size, len(test)),
        sort_key=lambda x: len(x.time),
        sort_within_batch=True,
        repeat=False,
        bptt_len=bptt_length)
    self.bptt_length = bptt_length

    NESTED_TEXT_FIELD.build_vocab(train,
                                  vectors=emb_dim,
                                  vectors_cache=path_to_vectors,
                                  max_size=voc_size,
                                  min_freq=min_freq)
    self.train_vocab = NESTED_TEXT_FIELD.vocab
    self.fix_length = NESTED_TEXT_FIELD.fix_length

def __init__(self, device, **kwargs):
    batch_size = kwargs.pop('batch_size')
    path_to_vectors = kwargs.pop('path_to_vectors')
    emb_dim = kwargs.pop('emb_dim')
    voc_size = kwargs.pop('voc_size', None)
    min_freq = kwargs.pop('min_freq')
    fix_len = kwargs.pop('fix_len', None)
    bptt_length = kwargs.pop('bptt_len')
    bow_size = kwargs.get('bow_size')
    server = kwargs.pop('server', 'localhost')
    data_collection_name = kwargs.pop('data_collection')
    self.__t_max = kwargs.pop('t_max')

    train_col = f'{data_collection_name}_train'
    val_col = f'{data_collection_name}_validation'
    test_col = f'{data_collection_name}_test'

    # If no t_max is given, read the global min/max event times from MongoDB.
    db = MongoClient('mongodb://' + server)['hawkes_text']
    col = db[train_col]
    if self.__t_max is None:
        min_max_values = list(
            col.aggregate([
                {"$project": {"_id": 0, "time": 1}},
                {"$unwind": "$time"},
                {"$group": {"_id": None,
                            "max": {"$max": "$time"},
                            "min": {"$min": "$time"}}},
                {"$limit": 1},
            ]))[0]
        self.__t_min = min_max_values['min']
        self.__t_max = min_max_values['max']

    # part_scale = partial(min_max_scale, min_value=self.__t_min, max_value=self.__t_max)

    # Event-time sequences (BPTT-chunked, not vocabulary-indexed).
    FIELD_TIME = data.BPTTField(bptt_length=bptt_length,
                                use_vocab=False,
                                fix_length=fix_len,
                                include_lengths=True,
                                pad_token=[-1.0, -1.0, -1.0],
                                preprocessing=partial(delta, t_max=self.__t_max),
                                dtype=torch.float64)

    # Bag-of-words sequences, padded with zero vectors.
    FIELD_BOW = data.BPTTField(bptt_length=bptt_length,
                               use_vocab=False,
                               fix_length=fix_len,
                               include_lengths=False,
                               pad_token=np.zeros(bow_size),
                               preprocessing=unpack_bow2seq,
                               dtype=torch.float64)

    FIELD_TEXT = data.ReversibleField(init_token='<sos>',
                                      eos_token='<eos>',
                                      tokenize=tokenizer,
                                      batch_first=True,
                                      use_vocab=True)

    NESTED_TEXT_FIELD = data.NestedBPTTField(FIELD_TEXT,
                                             bptt_length=bptt_length,
                                             use_vocab=False,
                                             fix_length=fix_len,
                                             preprocessing=unpack_text,
                                             include_lengths=True)

    train, valid, test = datasets.TextPointDataSet.splits(server,
                                                          time_field=FIELD_TIME,
                                                          text_field=NESTED_TEXT_FIELD,
                                                          bow_field=FIELD_BOW,
                                                          train=train_col,
                                                          validation=val_col,
                                                          test=test_col,
                                                          **kwargs)

    if fix_len == -1:
        max_len = max([train.max_len, valid.max_len, test.max_len])
        FIELD_TIME.fix_length = max_len
        FIELD_BOW.fix_length = max_len
        NESTED_TEXT_FIELD.fix_length = max_len

    self._train_iter, self._valid_iter, self._test_iter = data.BPTTIterator.splits(
        (train, valid, test),
        batch_sizes=(batch_size, batch_size, len(test)),
        sort_key=lambda x: len(x.time),
        sort_within_batch=True,
        repeat=False,
        bptt_len=bptt_length,
        device=device)
    self._bptt_length = bptt_length

    NESTED_TEXT_FIELD.build_vocab(train,
                                  vectors=emb_dim,
                                  vectors_cache=path_to_vectors,
                                  max_size=voc_size,
                                  min_freq=min_freq)
    self.train_vocab = NESTED_TEXT_FIELD.vocab
    self._fix_length = NESTED_TEXT_FIELD.fix_length
    self._bow_size = bow_size

def __init__(self, device, dtype=torch.float32, **kwargs):
    super().__init__(device, **kwargs)
    kwargs = self.dataset_kwargs

    time_fix_len = kwargs.pop('time_fix_len', None)
    text_fix_len = kwargs.pop('text_fix_len', None)
    bptt_length = kwargs.pop('bptt_len')
    self._t_max = kwargs.pop('t_max')
    server = kwargs.pop('server', 'localhost')
    data_collection_name = kwargs.pop('data_collection')
    db_name = kwargs.pop('db')

    train_col = f'{data_collection_name}_train'
    val_col = f'{data_collection_name}_validation'
    test_col = f'{data_collection_name}_test'

    # If no t_max is given, read the global min/max event times from MongoDB.
    db = MongoClient(f'mongodb://{server}/')[db_name]
    col = db[train_col]
    if self._t_max is None:
        min_max_values = list(
            col.aggregate([
                {"$project": {"_id": 0, "time": 1}},
                {"$unwind": "$time"},
                {"$group": {"_id": None,
                            "max": {"$max": "$time"},
                            "min": {"$min": "$time"}}},
                {"$limit": 1},
            ]))[0]
        self._t_min = min_max_values['min']
        self._t_max = min_max_values['max']

    # part_scale = partial(min_max_scale, min_value=self.__t_min, max_value=self.__t_max)

    FIELD_TIME = BPTTField(bptt_length=bptt_length,
                           use_vocab=False,
                           include_lengths=True,
                           pad_token=np.array([0., 0., -1.]),
                           preprocessing=partial(delta, t_max=self._t_max),
                           dtype=dtype,
                           fix_length=time_fix_len)

    FIELD_TEXT = data.ReversibleField(init_token='<sos>',
                                      eos_token='<eos>',
                                      unk_token='<unk>',
                                      tokenize=partial(tokenizer, punct=self.punctuation),
                                      batch_first=True,
                                      use_vocab=True,
                                      fix_length=text_fix_len)

    NESTED_TEXT_FIELD = data.NestedBPTTField(FIELD_TEXT,
                                             bptt_length=bptt_length,
                                             use_vocab=False,
                                             fix_length=time_fix_len,
                                             preprocessing=unpack_text,
                                             include_lengths=True)

    train, valid, test = datasets.TextPointDataSet.splits(server,
                                                          db_name,
                                                          time_field=FIELD_TIME,
                                                          text_field=NESTED_TEXT_FIELD,
                                                          train=train_col,
                                                          validation=val_col,
                                                          test=test_col,
                                                          **kwargs)

    # time_fix_len == -1 is treated as "pad to the longest event sequence seen".
    if time_fix_len == -1:
        max_len = max([train.max_len, valid.max_len, test.max_len])
        FIELD_TIME.fix_length = max_len
        NESTED_TEXT_FIELD.fix_length = max_len

    self._train_iter, self._valid_iter, self._test_iter = data.BPTTIterator.splits(
        (train, valid, test),
        batch_sizes=(self.batch_size, self.batch_size, len(test)),
        sort_key=lambda x: x.time.shape[0],
        sort_within_batch=True,
        repeat=False,
        bptt_len=bptt_length,
        device=device)

    NESTED_TEXT_FIELD.build_vocab(train,
                                  vectors=self.emb_dim,
                                  vectors_cache=self.path_to_vectors,
                                  max_size=self.voc_size,
                                  min_freq=self.min_freq)
    self._bptt_length = bptt_length
    self.train_vocab = NESTED_TEXT_FIELD.vocab
    self._time_fix_length = NESTED_TEXT_FIELD.fix_length
    self._text_fix_length = FIELD_TEXT.fix_length
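
# --- Usage sketch (illustrative only) --------------------------------------------
# Given an already-constructed instance of the loader class that the __init__
# above belongs to (its name and full constructor signature are not shown here),
# the BPTT iterators can be consumed like any torchtext iterator.  The `time` and
# `text` batch attributes are assumptions based on the field names passed to
# datasets.TextPointDataSet.splits and on the sort_key above.
def _peek_first_point_process_batch(loader):
    """Print the shape of the first training batch as a quick sanity check."""
    batch = next(iter(loader._train_iter))
    times, time_lengths = batch.time   # padded event-time triples plus lengths
    print('time tensor:', times.shape, '| text batch:', type(batch.text))
    return batch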