Example 1
    def __init__(self, dataset, batch_size, sort_key=None, device=None,
                 batch_size_fn=None,
                 repeat=False, shuffle=False, sort=False,
                 sort_within_batch=False):
        self.batch_size, self.dataset = batch_size, dataset
        self.batch_size_fn = batch_size_fn
        self.batch_indices = None
        self.batch_sizes = None
        self.iterations = 0
        self.repeat = repeat
        self.shuffle = shuffle
        self.sort = sort
        self.sort_within_batch = sort_within_batch

        if sort_key is None:
            self.sort_key = dataset.sort_key
        else:
            self.sort_key = sort_key

        if type(device) == int:
            logger.warning("The `device` argument should be set by using `torch.device`" +
                           " or passing a string as an argument. This behavior will be" +
                           " deprecated soon and currently defaults to cpu.")
            device = None
        self.device = device
        self.random_shuffler = RandomShuffler()

        # For state loading/saving only
        self._iterations_this_epoch = 0
        self._random_state_this_epoch = None
        self._restored_from_state = False
Example 2
    def __init__(self, train_shards, fields, device, opt):
        self.index = -1
        self.iterables = []
        self.weights = []
        for shard, weight in zip(train_shards, opt.data_weights):
            if weight > 0:
                self.iterables.append(
                    build_dataset_iter(shard, fields, opt, multi=True))
                self.weights.append(weight)
        self.init_iterators = True
        # self.weights = opt.data_weights
        self.batch_size = opt.batch_size
        self.batch_size_fn = max_tok_len \
            if opt.batch_type == "tokens" else None
        if opt.batch_size_multiple is not None:
            self.batch_size_multiple = opt.batch_size_multiple
        else:
            self.batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1
        self.device = device
        # Temporarily load one shard to retrieve sort_key for data_type
        temp_dataset = torch.load(self.iterables[0]._paths[0])
        self.sort_key = temp_dataset.sort_key
        self.random_shuffler = RandomShuffler()
        self.pool_factor = opt.pool_factor
        del temp_dataset
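
Example 2 sets batch_size_fn = max_tok_len when opt.batch_type == "tokens", so the iterator measures batches in tokens rather than in examples. The sketch below is an illustrative stand-in, not OpenNMT-py's max_tok_len; it assumes the torchtext convention that batch_size_fn(new, count, sofar) is called after each example is added, with count including the new example, and that the returned value is treated as the current batch size. The src attribute on the example is also an assumption.

from collections import namedtuple

class TokenBatchSize:
    """Illustrative token-counting batch_size_fn (not max_tok_len)."""

    def __init__(self):
        self.longest = 0

    def __call__(self, new, count, sofar):
        # count includes the example just added; count == 1 marks a new batch
        if count == 1:
            self.longest = 0
        self.longest = max(self.longest, len(new.src))
        # Batch "size" = number of examples * longest source seen so far,
        # which bounds the padded batch by tokens instead of by examples.
        return count * self.longest

Example = namedtuple("Example", "src")
fn = TokenBatchSize()
print(fn(Example(src=[1, 2, 3]), 1, 0))          # 1 * 3 = 3
print(fn(Example(src=[1, 2, 3, 4, 5]), 2, 3))    # 2 * 5 = 10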
Example 3
    def set_rng_state(self, state):
        """
        Restore the RNG state

        :param state: state of RNG to be restored
        """
        self._random = RandomShuffler(state)
        # Generate the batches by following the state
        self.reset()
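
The state passed to set_rng_state is the kind of value described in the split() docstring of the next example: a return value of random.getstate(). A minimal standard-library sketch of capturing and restoring such a state:

import random

# Capture the RNG state, draw, then restore the state and draw again:
# both draws are identical, which is exactly what restoring the state buys.
state = random.getstate()
first = random.sample(range(100), 5)

random.setstate(state)
second = random.sample(range(100), 5)
assert first == second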
Example 4
    def split(self,
              split_ratio=0.7,
              stratified=False,
              strata_field="label",
              random_state=None):
        """Create train-test(-valid?) splits from the instance's examples.

        Arguments:
            split_ratio (float or List of floats): a number in [0, 1] denoting the amount
                of data to be used for the training split (rest is used for test),
                or a list of numbers denoting the relative sizes of train, test and valid
                splits respectively. If the relative size for valid is missing, only the
                train-test split is returned. Default is 0.7 (for the train set).
            stratified (bool): whether the sampling should be stratified.
                Default is False.
            strata_field (str): name of the examples Field stratified over.
                Default is 'label' for the conventional label field.
            random_state (tuple): the random seed used for shuffling.
                A return value of `random.getstate()`.

        Returns:
            Tuple[Dataset]: Datasets for train, validation, and
            test splits in that order, if the splits are provided.
        """
        train_ratio, test_ratio, val_ratio = check_split_ratio(split_ratio)

        # For the permutations
        rnd = RandomShuffler(random_state)
        if not stratified:
            train_data, test_data, val_data = rationed_split(
                self.examples, train_ratio, test_ratio, val_ratio, rnd)
        else:
            if strata_field not in self.fields:
                raise ValueError(
                    "Invalid field name for strata_field {}".format(
                        strata_field))
            strata = stratify(self.examples, strata_field)
            train_data, test_data, val_data = [], [], []
            for group in strata:
                # Stratify each group and add together the indices.
                group_train, group_test, group_val = rationed_split(
                    group, train_ratio, test_ratio, val_ratio, rnd)
                train_data += group_train
                test_data += group_test
                val_data += group_val

        splits = tuple(
            Dataset(d, self.fields) for d in (train_data, val_data, test_data)
            if d)

        # In case the parent sort_key isn't None
        if self.sort_key:
            for subset in splits:
                subset.sort_key = self.sort_key
        return splits
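
For a concrete sense of what the ratio handling does, here is a small stand-alone sketch; simple_rationed_split is a hypothetical helper, not the library's rationed_split, and it ignores stratification.

import random

def simple_rationed_split(examples, train_ratio, test_ratio, val_ratio, rnd):
    # Shuffle indices once, then cut at the train and test boundaries;
    # whatever remains becomes the validation portion.
    indices = list(range(len(examples)))
    rnd.shuffle(indices)
    n_train = int(round(train_ratio * len(examples)))
    n_test = int(round(test_ratio * len(examples)))
    train = [examples[i] for i in indices[:n_train]]
    test = [examples[i] for i in indices[n_train:n_train + n_test]]
    val = [examples[i] for i in indices[n_train + n_test:]]
    return train, test, val

train, test, val = simple_rationed_split(list(range(100)), 0.7, 0.2, 0.1, random)
print(len(train), len(test), len(val))   # 70, 20, 10 for these ratios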
Example 5
    def _postprocess(self, device, opt):
        self.init_iterators = True
        self.weights = opt.data_weights
        self.batch_size = opt.batch_size
        self.batch_size_fn = max_tok_len \
            if opt.batch_type == "tokens" else None
        self.batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1
        self.device = device
        # Temporarily load one shard to retrieve sort_key for data_type
        temp_dataset = torch.load(self.iterables[0]._paths[0])
        self.sort_key = temp_dataset.sort_key
        self.random_shuffler = RandomShuffler()
        self.pool_factor = opt.pool_factor
        del temp_dataset
Example 6
    def __init__(
        self,
        dataset,
        batch_size,
        sort_key=None,
        device=None,
        batch_size_fn=None,
        train=True,
        repeat=False,
        shuffle=None,
        sort=None,
        sort_within_batch=None,
    ):
        self.batch_size, self.train, self.dataset = batch_size, train, dataset
        self.batch_size_fn = batch_size_fn
        self.iterations = 0
        self.repeat = repeat
        self.shuffle = train if shuffle is None else shuffle
        self.sort = not train if sort is None else sort

        if sort_within_batch is None:
            self.sort_within_batch = self.sort
        else:
            self.sort_within_batch = sort_within_batch
        if sort_key is None:
            self.sort_key = dataset.sort_key
        else:
            self.sort_key = sort_key

        if isinstance(device, int):
            logger.warning(
                "The `device` argument should be set by using `torch.device`" +
                " or passing a string as an argument. This behavior will be" +
                " deprecated soon and currently defaults to cpu.")
            device = None

        if device is None:
            device = torch.device("cpu")
        elif isinstance(device, str):
            device = torch.device(device)

        self.device = device
        self.random_shuffler = RandomShuffler()

        # For state loading/saving only
        self._iterations_this_epoch = 0
        self._random_state_this_epoch = None
        self._restored_from_state = False
Example 7
    def __init__(self, iterables, device, opt):
        self.index = -1
        self.iterators = [iter(iterable) for iterable in iterables]
        self.iterables = iterables
        self.weights = opt.data_weights
        self.batch_size = opt.batch_size
        self.batch_size_fn = max_tok_len \
            if opt.batch_type == "tokens" else None
        self.batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1
        self.device = "cuda" if device >= 0 else "cpu"
        # Temporarily load one shard to retrieve sort_key for data_type
        temp_dataset = torch.load(self.iterables[0]._paths[0])
        self.sort_key = temp_dataset.sort_key
        self.random_shuffler = RandomShuffler()
        self.pool_factor = opt.pool_factor
        del temp_dataset
Example 8
    def __init__(self,
                 dataset,
                 batch_size,
                 sort_key=None,
                 device=None,
                 batch_size_fn=None,
                 train=True,
                 repeat=False,
                 shuffle=None,
                 sort=None,
                 sort_within_batch=None,
                 poolnum=None):

        self.fields = dict(dataset.fields)
        self.batch_size, self.train, self.dataset = batch_size, train, dataset
        self.poolnum = poolnum
        self.batch_size_fn = batch_size_fn
        self.iterations = 0
        self.repeat = repeat
        self.shuffle = train if shuffle is None else shuffle
        self.sort = not train if sort is None else sort

        if sort_within_batch is None:
            self.sort_within_batch = self.sort
        else:
            self.sort_within_batch = sort_within_batch

        if sort_key is None:
            self.sort_key = lambda ex: interleave_keys(
                len(ex.src), len(ex.tgt))
        else:
            self.sort_key = sort_key

        if type(device) == int:
            logger.warning(
                "The `device` argument should be set by using `torch.device`" +
                " or passing a string as an argument. This behavior will be" +
                " deprecated soon and currently defaults to cpu.")
            device = None
        self.device = device
        self.random_shuffler = RandomShuffler()

        # For state loading/saving only
        self._iterations_this_epoch = 0
        self._random_state_this_epoch = None
        self._restored_from_state = False
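
The default sort_key here builds a single integer from both sequence lengths, so sorting by it orders examples by source and target length together and keeps similarly sized pairs in the same minibatch. A minimal sketch of that idea follows; interleave_bits is an illustrative stand-in for torchtext's interleave_keys.

def interleave_bits(a, b):
    # Alternate the bits of the two lengths, so the combined key sorts
    # examples by both lengths at once.
    a_bits = format(a, '016b')
    b_bits = format(b, '016b')
    return int(''.join(x + y for x, y in zip(a_bits, b_bits)), 2)

print(interleave_bits(5, 7))    # one key encoding both lengths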
Example 9
    def __init__(self, src_types, train_shards, fields, device, opt):
        self.index = -1
        self.iterables = []
        for shard in train_shards:
            self.iterables.append(
                MultiSourceInputter.build_dataset_iter(
                    src_types, shard, fields, opt, multi=True))
        self.init_iterators = True
        self.weights = opt.data_weights
        self.batch_size = opt.batch_size
        self.batch_size_fn = MultiSourceInputter.max_tok_len \
            if opt.batch_type == "tokens" else None
        self.batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1
        self.device = device
        # Temporarily load one shard to retrieve sort_key for data_type
        temp_dataset = torch.load(self.iterables[0]._paths[0])
        self.sort_key = temp_dataset.sort_key
        self.random_shuffler = RandomShuffler()
        self.pool_factor = opt.pool_factor
        del temp_dataset
Example 10
    def __init__(self,
                 dataset: str,
                 problem_field: ProblemTextField,
                 op_gen_field: OpEquationField,
                 expr_gen_field: ExpressionEquationField,
                 expr_ptr_field: ExpressionEquationField,
                 token_batch_size: int = 4096,
                 testing_purpose: bool = False):
        """
        Instantiate batch iterator

        :param str dataset: Path of the JSON Lines (JSONL) file to be loaded.
        :param ProblemTextField problem_field: Text field for encoder
        :param OpEquationField op_gen_field: OP-token equation field for decoder
        :param ExpressionEquationField expr_gen_field: Expression-token equation field for decoder (no pointer)
        :param ExpressionEquationField expr_ptr_field: Expression-token equation field for decoder (pointer)
        :param int token_batch_size: Maximum bound for batch size in terms of tokens.
        :param bool testing_purpose:
            True if this dataset is for testing. Otherwise, we will randomly shuffle the dataset.
        """
        # Define fields
        self.problem_field = problem_field
        self.op_gen_field = op_gen_field
        self.expr_gen_field = expr_gen_field
        self.expr_ptr_field = expr_ptr_field

        # Store the batch size
        self._batch_size = token_batch_size
        # Store whether this dataset is for testing or not.
        self._testing_purpose = testing_purpose

        # Storage for list of shuffled batches
        self._batches = None
        # Iterator for batches
        self._iterator = None
        # Random shuffler
        self._random = RandomShuffler() if not testing_purpose else None

        # Read the dataset.
        cached_path = Path(dataset + '.cached')
        if cached_path.exists():
            # If cached version is available, load the dataset from it.
            cache = load_data(cached_path)
            self._dataset = cache['dataset']
            vocab_cache = cache['vocab']

            # Restore vocabulary from the dataset.
            if self.op_gen_field.has_empty_vocab:
                self.op_gen_field.token_vocab = vocab_cache['token']
            if self.expr_gen_field.has_empty_vocab:
                self.expr_gen_field.operator_word_vocab = vocab_cache['func']
                self.expr_gen_field.constant_word_vocab = vocab_cache['arg']
            if self.expr_ptr_field.has_empty_vocab:
                self.expr_ptr_field.operator_word_vocab = vocab_cache['func']
                self.expr_ptr_field.constant_word_vocab = vocab_cache['const']
        else:
            # Otherwise, compute the preprocessed result and cache it on disk.
            # First, read the JSON Lines file.
            _dataset = []
            _items_for_vocab = []
            with Path(dataset).open('r+t', encoding='UTF-8') as fp:
                for line in fp.readlines():
                    line = line.strip()
                    if not line:
                        continue

                    item = json.loads(line)

                    # We only need 'text', 'expr', 'id', and 'answer'
                    _dataset.append((item['text'], item['expr'], item['id'],
                                     item['answer']))
                    # Separately gather equations to build vocab.
                    _items_for_vocab.append(item['expr'])

            # Build vocab if it is empty
            if self.op_gen_field.has_empty_vocab:
                self.op_gen_field.build_vocab(_items_for_vocab)
            if self.expr_gen_field.has_empty_vocab:
                self.expr_gen_field.build_vocab(_items_for_vocab)
            if self.expr_ptr_field.has_empty_vocab:
                self.expr_ptr_field.build_vocab(_items_for_vocab)

            # Run preprocessing
            self._dataset = [
                self._tokenize_equation(item) for item in _dataset
            ]

            # Cache dataset and vocabulary.
            save_data(
                {
                    'dataset': self._dataset,
                    'vocab': {
                        'token': self.op_gen_field.token_vocab,
                        'func': self.expr_gen_field.operator_word_vocab,
                        'arg': self.expr_gen_field.constant_word_vocab,
                        'const': self.expr_ptr_field.constant_word_vocab,
                    }
                }, cached_path)

        # Compute the number of examples
        self._examples = len(self._dataset)
        # Generate the batches.
        self.reset()